SouzaAr committed on
Commit dd830c9 · verified · 1 Parent(s): 0ac0eba

Upload 58 files

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. unsloth-main/unsloth-main/.github/FUNDING.yml +13 -0
  2. unsloth-main/unsloth-main/LICENSE +201 -0
  3. unsloth-main/unsloth-main/README.md +455 -0
  4. unsloth-main/unsloth-main/images/Assistant.png +0 -0
  5. unsloth-main/unsloth-main/images/Colab.png +0 -0
  6. unsloth-main/unsloth-main/images/Discord button.png +0 -0
  7. unsloth-main/unsloth-main/images/Discord.png +0 -0
  8. unsloth-main/unsloth-main/images/Free version button.png +0 -0
  9. unsloth-main/unsloth-main/images/Kaggle.png +0 -0
  10. unsloth-main/unsloth-main/images/Kofi button.png +0 -0
  11. unsloth-main/unsloth-main/images/LAION 2GPU.png +0 -0
  12. unsloth-main/unsloth-main/images/Merge.png +0 -0
  13. unsloth-main/unsloth-main/images/Run.png +0 -0
  14. unsloth-main/unsloth-main/images/Slim Orca 2GPUs.png +0 -0
  15. unsloth-main/unsloth-main/images/Terminal_Type.png +0 -0
  16. unsloth-main/unsloth-main/images/Where_Terminal.png +0 -0
  17. unsloth-main/unsloth-main/images/buy me a coffee button.png +0 -0
  18. unsloth-main/unsloth-main/images/made with unsloth.png +0 -0
  19. unsloth-main/unsloth-main/images/ollama.png +0 -0
  20. unsloth-main/unsloth-main/images/peft x trl button.png +0 -0
  21. unsloth-main/unsloth-main/images/start free finetune button.png +0 -0
  22. unsloth-main/unsloth-main/images/unsloth end.png +0 -0
  23. unsloth-main/unsloth-main/images/unsloth loading page render.png +0 -0
  24. unsloth-main/unsloth-main/images/unsloth logo black text.png +0 -0
  25. unsloth-main/unsloth-main/images/unsloth logo only.png +0 -0
  26. unsloth-main/unsloth-main/images/unsloth logo white text.png +0 -0
  27. unsloth-main/unsloth-main/images/unsloth made with love.png +0 -0
  28. unsloth-main/unsloth-main/images/unsloth new logo.png +0 -0
  29. unsloth-main/unsloth-main/pyproject.toml +327 -0
  30. unsloth-main/unsloth-main/unsloth-cli.py +221 -0
  31. unsloth-main/unsloth-main/unsloth/__init__.py +161 -0
  32. unsloth-main/unsloth-main/unsloth/_auto_install.py +30 -0
  33. unsloth-main/unsloth-main/unsloth/chat_templates.py +2210 -0
  34. unsloth-main/unsloth-main/unsloth/kernels/__init__.py +61 -0
  35. unsloth-main/unsloth-main/unsloth/kernels/cross_entropy_loss.py +461 -0
  36. unsloth-main/unsloth-main/unsloth/kernels/fast_lora.py +412 -0
  37. unsloth-main/unsloth-main/unsloth/kernels/flex_attention.py +180 -0
  38. unsloth-main/unsloth-main/unsloth/kernels/geglu.py +203 -0
  39. unsloth-main/unsloth-main/unsloth/kernels/layernorm.py +231 -0
  40. unsloth-main/unsloth-main/unsloth/kernels/rms_layernorm.py +283 -0
  41. unsloth-main/unsloth-main/unsloth/kernels/rope_embedding.py +181 -0
  42. unsloth-main/unsloth-main/unsloth/kernels/swiglu.py +99 -0
  43. unsloth-main/unsloth-main/unsloth/kernels/utils.py +416 -0
  44. unsloth-main/unsloth-main/unsloth/models/__init__.py +20 -0
  45. unsloth-main/unsloth-main/unsloth/models/_utils.py +1140 -0
  46. unsloth-main/unsloth-main/unsloth/models/cohere.py +473 -0
  47. unsloth-main/unsloth-main/unsloth/models/dpo.py +130 -0
  48. unsloth-main/unsloth-main/unsloth/models/gemma.py +430 -0
  49. unsloth-main/unsloth-main/unsloth/models/gemma2.py +581 -0
  50. unsloth-main/unsloth-main/unsloth/models/llama.py +0 -0
unsloth-main/unsloth-main/.github/FUNDING.yml ADDED
@@ -0,0 +1,13 @@
# These are supported funding model platforms

github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: unsloth
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
unsloth-main/unsloth-main/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
unsloth-main/unsloth-main/README.md ADDED
@@ -0,0 +1,455 @@
<div align="center">

<a href="https://unsloth.ai"><picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png">
<img alt="unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png" height="110" style="max-width: 100%;">
</picture></a>

<a href="https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/start free finetune button.png" height="48"></a>
<a href="https://discord.gg/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/Discord button.png" height="48"></a>
<a href="https://ko-fi.com/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/buy me a coffee button.png" height="48"></a>

### Finetune Llama 3.2, Mistral, Phi-3.5 & Gemma 2-5x faster with 80% less memory!

![](https://i.ibb.co/sJ7RhGG/image-41.png)

</div>

## ✨ Finetune for Free

All notebooks are **beginner friendly**! Add your dataset, click "Run All", and you'll get a 2x faster finetuned model which can be exported to GGUF, Ollama, vLLM or uploaded to Hugging Face.

| Unsloth supports | Free Notebooks | Performance | Memory use |
|-----------|---------|--------|----------|
| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing) | 2x faster | 60% less |
| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) | 2x faster | 60% less |
| **Phi-3.5 (mini)** | [▶️ Start for free](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) | 2x faster | 50% less |
| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) | 2x faster | 63% less |
| **Mistral Small (22B)** | [▶️ Start for free](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) | 2x faster | 60% less |
| **Ollama** | [▶️ Start for free](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing) | 1.9x faster | 43% less |
| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) | 2.2x faster | 73% less |
| **ORPO** | [▶️ Start for free](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) | 1.9x faster | 43% less |
| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 43% less |

- **Kaggle Notebooks** for [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)
- Run the [Llama 3.2 1B & 3B notebook](https://colab.research.google.com/drive/1hoHFpf7ROqk_oZHzxQdfPW9yvTxnvItq?usp=sharing) and the [Llama 3.2 conversational notebook](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing)
- Run the [Llama 3.1 conversational notebook](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) and the [Mistral v0.3 ChatML notebook](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing)
- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text
- This [continued pretraining notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) is for learning another language
- Click [here](https://github.com/unslothai/unsloth/wiki) for detailed documentation for Unsloth.

## 🦥 Unsloth.ai News
- 📣 NEW! [Llama 3.2 Conversational notebook](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing) includes training only on completions / outputs (increases accuracy), ShareGPT standardization and more!
- 📣 NEW! [Llama 3.2 Kaggle notebook](https://www.kaggle.com/danielhanchen/kaggle-llama-3-2-1b-3b-unsloth-notebook) and [Llama 3.2 Kaggle conversational notebook](https://www.kaggle.com/code/danielhanchen/kaggle-llama-3-2-1b-3b-conversational-unsloth/notebook)
- 📣 NEW! [Qwen 2.5 7b notebook](https://colab.research.google.com/drive/1Kose-ucXO1IBaZq5BvbwWieuubP7hxvQ?usp=sharing) finetuning is supported! Qwen 2.5 comes in multiple sizes - check our [4bit uploads](https://huggingface.co/unsloth) for 4x faster downloads! 14b fits in a Colab GPU! See also the [Qwen 2.5 conversational notebook](https://colab.research.google.com/drive/1qN1CEalC70EO1wGKhNxs1go1W9So61R5?usp=sharing)
- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) finetuning fits in under 16GB of VRAM!
- 📣 NEW! [Phi-3.5 (mini)](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) now supported
- 📣 NEW! [Gemma-2-2b](https://colab.research.google.com/drive/1weTpKOjBZxZJ5PQ-Ql8i6ptAY2x-FWVA?usp=sharing) now supported! Try out the [Chat interface](https://colab.research.google.com/drive/1i-8ESvtLRGNkkUQQr_-z_rcSAIo9c3lM?usp=sharing)!
- 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) & [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) - both Base and Instruct are now supported
<details>
<summary>Click for more news</summary>

- 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows installs without a git pull. Use `pip install unsloth[colab-new]` for installs without dependencies.
- 📣 NEW! [Gemma-2-9b](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) and Gemma-2-27b now supported
- 📣 UPDATE! [Phi-3 mini](https://colab.research.google.com/drive/1hhdhBa1j_hsymiW9m-WzxQtgqTH_NHqi?usp=sharing) model updated. [Phi-3 Medium](https://colab.research.google.com/drive/1hhdhBa1j_hsymiW9m-WzxQtgqTH_NHqi?usp=sharing) 2x faster finetuning.
- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) for other languages like Korean!
- 📣 NEW! Qwen2 now works
- 📣 [Mistral v0.3 Base](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) and [Mistral v0.3 Instruct]
- 📣 [ORPO support](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) is here + [2x faster inference](https://colab.research.google.com/drive/1aqlNQi7MMJbynFDyOQteD2t0yVfjb9Zh?usp=sharing) added for all our models
- 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
</details>

## 🔗 Links and Resources
| Type | Links |
| ------------------------------- | --------------------------------------- |
| 📚 **Documentation & Wiki** | [Read Our Docs](https://docs.unsloth.ai) |
| <img height="14" src="https://upload.wikimedia.org/wikipedia/commons/6/6f/Logo_of_Twitter.svg" />&nbsp; **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)|
| 💾 **Installation** | [unsloth/README.md](https://github.com/unslothai/unsloth/tree/main#-installation-instructions)|
| 🥇 **Benchmarking** | [Performance Tables](https://github.com/unslothai/unsloth/tree/main#-performance-benchmarking)|
| 🌐 **Released Models** | [Unsloth Releases](https://huggingface.co/unsloth)|
| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog)|

## ⭐ Key Features
- All kernels written in [OpenAI's Triton](https://openai.com/research/triton) language. **Manual backprop engine**.
- **0% loss in accuracy** - no approximation methods - all exact.
- No hardware changes needed. Supports NVIDIA GPUs from 2018 onwards. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40, etc.) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070 and 1080 work, but are slow.
- Works on **Linux** and **Windows** via WSL.
- Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
- The open-source version trains 5x faster - see [Unsloth Pro](https://unsloth.ai/) for up to **30x faster training**!
- If you trained a model with 🦥Unsloth, you can use this cool sticker! &nbsp; <img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/made with unsloth.png" height="50" align="center" />


## 🥇 Performance Benchmarking
- For the full list of **reproducible** benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)

| 1 A100 40GB | 🤗Hugging Face | Flash Attention | 🦥Unsloth Open Source | 🦥[Unsloth Pro](https://unsloth.ai/pricing) |
|--------------|--------------|-----------------|---------------------|-----------------|
| Alpaca | 1x | 1.04x | 1.98x | **15.64x** |
| LAION Chip2 | 1x | 0.92x | 1.61x | **20.73x** |
| OASST | 1x | 1.19x | 2.17x | **14.83x** |
| Slim Orca | 1x | 1.18x | 2.22x | **14.82x** |

- The benchmarking table below was produced by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl).

| Free Colab T4 | Dataset | 🤗Hugging Face | Pytorch 2.1.1 | 🦥Unsloth | 🦥 VRAM reduction |
| --- | --- | --- | --- | --- | --- |
| Llama-2 7b | OASST | 1x | 1.19x | 1.95x | -43.3% |
| Mistral 7b | Alpaca | 1x | 1.07x | 1.56x | -13.7% |
| Tiny Llama 1.1b | Alpaca | 1x | 2.06x | 3.87x | -73.8% |
| DPO with Zephyr | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |

![](https://i.ibb.co/sJ7RhGG/image-41.png)

## 💾 Installation Instructions

For stable releases, use `pip install unsloth`. For most installations, though, we recommend `pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"`.

### Conda Installation
`⚠️Only use Conda if you have it. If not, use Pip`. Select `pytorch-cuda=11.8` for CUDA 11.8 or `pytorch-cuda=12.1` for CUDA 12.1. We support `python=3.10,3.11,3.12`.
```bash
conda create --name unsloth_env \
    python=3.11 \
    pytorch-cuda=12.1 \
    pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
    -y
conda activate unsloth_env

pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps trl peft accelerate bitsandbytes
```

<details>
<summary>If you're looking to install Conda in a Linux environment, <a href="https://docs.anaconda.com/miniconda/">read here</a>, or run the below 🔽</summary>

```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm -rf ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
~/miniconda3/bin/conda init zsh
```
</details>

### Pip Installation
`⚠️Do **NOT** use this if you have Conda.` Pip is a bit more complex since there are dependency issues. The pip command differs for `torch 2.2,2.3,2.4,2.5` and for different CUDA versions.

For other torch versions, we support `torch211`, `torch212`, `torch220`, `torch230`, and `torch240`; for CUDA versions, we support `cu118` and `cu121`. For Ampere devices (A100, H100, RTX3090) and above, use `cu118-ampere` or `cu121-ampere`.

For example, if you have `torch 2.4` and `CUDA 12.1`, use:
```bash
pip install --upgrade pip
pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
```

And other examples:
```bash
pip install "unsloth[cu121-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-torch240] @ git+https://github.com/unslothai/unsloth.git"

pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
```

Or, run the below in a terminal to get the **optimal** pip installation command:
```bash
wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/_auto_install.py | python -
```

Or, run the below manually in a Python REPL:
```python
try: import torch
except: raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
if cuda != "12.1" and cuda != "11.8": raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v < V('2.3.0'): x = 'cu{}{}-torch220'
elif v < V('2.4.0'): x = 'cu{}{}-torch230'
elif v < V('2.5.0'): x = 'cu{}{}-torch240'
else: raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
```

For **advanced installation instructions**, or if you see weird errors during installation, work through the steps below (a combined check script is sketched after the list):

1. Install `torch` and `triton`. Go to https://pytorch.org to install them. For example, `pip install torch torchvision torchaudio triton`
2. Confirm that CUDA is installed correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or the CUDA drivers.
3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check whether `xformers` installed correctly with `python -m xformers.info`. See https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs.
4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`.
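As a convenience, the checks from steps 1-4 can be wrapped in a single Python snippet. This is only a sketch that reuses the commands already listed above; the exact output and failure messages will depend on your setup:

```python
# Sanity-check sketch for the four steps above (wraps only the checks named there).
import importlib
import subprocess

# Step 1: torch should import and report its CUDA build.
import torch
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda,
      "| GPU available:", torch.cuda.is_available())

# Step 2: `nvcc` ships with the CUDA toolkit; a missing binary means the
# toolkit or drivers still need to be installed.
try:
    subprocess.run(["nvcc", "--version"], check = False)
except FileNotFoundError:
    print("nvcc not found - install cudatoolkit or the CUDA drivers (step 2)")

# Steps 3 and 4: confirm xformers and bitsandbytes import cleanly.
for package in ("xformers", "bitsandbytes"):
    try:
        importlib.import_module(package)
        print(package, "imported OK")
    except Exception as error:
        print(package, "failed:", error)
```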

## 📜 [Documentation](https://docs.unsloth.ai)
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
- We support Hugging Face's TRL, Trainer, Seq2SeqTrainer and even plain PyTorch code!
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-v0.3-bnb-4bit",      # New Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",           # Llama-3 15 trillion tokens model 2x faster!
    "unsloth/llama-3-8b-Instruct-bnb-4bit",
    "unsloth/llama-3-70b-bnb-4bit",
    "unsloth/Phi-3-mini-4k-instruct",        # Phi-3 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",             # Gemma 2.2x faster!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        max_steps = 60,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        output_dir = "outputs",
        optim = "adamw_8bit",
        seed = 3407,
    ),
)
trainer.train()

# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
# (1) Saving to GGUF / merging to 16bit for vLLM
# (2) Continued training from a saved LoRA adapter
# (3) Adding an evaluation loop / OOMs
# (4) Customized chat templates
```
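As an example, the GGUF / 16-bit saving options mentioned in point (1) above look roughly like the sketch below. The helper names follow the Unsloth wiki, and the exact arguments may differ between versions, so treat this as illustrative rather than canonical:

```python
# Sketch only: saving helpers as described in the Unsloth wiki; signatures may vary by version.
model.save_pretrained("lora_model")        # LoRA adapters only
tokenizer.save_pretrained("lora_model")

# Merge the LoRA weights into 16-bit weights for vLLM / Hugging Face
model.save_pretrained_merged("merged_model", tokenizer, save_method = "merged_16bit")

# Export to GGUF for llama.cpp / Ollama; "q4_k_m" is just one common quantization choice
model.save_pretrained_gguf("gguf_model", tokenizer, quantization_method = "q4_k_m")
```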

<a name="DPO"></a>
## DPO Support
DPO (Direct Preference Optimization), PPO, and reward modelling all appear to work, as per independent third-party testing from [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory). We have a preliminary Google Colab notebook for reproducing Zephyr on a Tesla T4 here: [notebook](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing).

We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
from unsloth import FastLanguageModel, PatchDPOTrainer
from unsloth import is_bfloat16_supported
PatchDPOTrainer()
import torch
from transformers import TrainingArguments
from trl import DPOTrainer
max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/zephyr-sft-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 64,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
)

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 8,
        warmup_ratio = 0.1,
        num_train_epochs = 3,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        seed = 42,
        output_dir = "outputs",
    ),
    beta = 0.1,
    train_dataset = YOUR_DATASET_HERE,
    # eval_dataset = YOUR_DATASET_HERE,
    tokenizer = tokenizer,
    max_length = 1024,
    max_prompt_length = 512,
)
dpo_trainer.train()
```
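In the snippet above, `YOUR_DATASET_HERE` is a placeholder: TRL's `DPOTrainer` expects a preference dataset with `prompt`, `chosen` and `rejected` text columns. A minimal toy stand-in (purely illustrative) looks like this:

```python
# Toy preference dataset with the columns DPOTrainer expects; replace with real data.
from datasets import Dataset

YOUR_DATASET_HERE = Dataset.from_dict({
    "prompt":   ["What is 2 + 2?"],
    "chosen":   ["2 + 2 equals 4."],
    "rejected": ["2 + 2 equals 5."],
})
```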

## 🥇 Detailed Benchmarking Tables
- Click "Code" for fully reproducible examples
- "Unsloth Equal" is a preview of our PRO version, with code stripped out. All settings and the loss curve remain identical.
- For the full list of benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)

| 1 A100 40GB | 🤗Hugging Face | Flash Attention 2 | 🦥Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 1.04x | 1.98x | 2.48x | 5.32x | **15.64x** |
| code | [Code](https://colab.research.google.com/drive/1u4dBeM-0vGNVmmO6X7cScAut-Hyt4KDF?usp=sharing) | [Code](https://colab.research.google.com/drive/1fgTOxpMbVjloQBvZyz4lF4BacKSZOB2A?usp=sharing) | [Code](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Code](https://colab.research.google.com/drive/1ANW8EFL3LVyTD7Gq4TkheC1Z7Rxw-rHp?usp=sharing) | | |
| seconds | 1040 | 1001 | 525 | 419 | 196 | 67 |
| memory MB | 18235 | 15365 | 9631 | 8525 | | |
| % saved | | 15.74 | 47.18 | 53.25 | | |

### Llama-Factory 3rd party benchmarking
- [Link to performance table.](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-Comparison) TGS: tokens per GPU per second. Model: LLaMA2-7B. GPU: NVIDIA A100 * 1. Batch size: 4. Gradient accumulation: 2. LoRA rank: 8. Max length: 1024.

| Method | Bits | TGS | GRAM | Speed |
| --- | --- | --- | --- | --- |
| HF | 16 | 2392 | 18GB | 100% |
| HF+FA2 | 16 | 2954 | 17GB | 123% |
| Unsloth+FA2 | 16 | 4007 | 16GB | **168%** |
| HF | 4 | 2415 | 9GB | 101% |
| Unsloth+FA2 | 4 | 3726 | 7GB | **160%** |

### Performance comparisons between popular models
<details>
<summary>Click for specific model benchmarking tables (Mistral 7b, CodeLlama 34b etc.)</summary>

### Mistral 7b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Mistral 7B Slim Orca | 1x | 1.15x | 2.15x | 2.53x | 4.61x | **13.69x** |
| code | [Code](https://colab.research.google.com/drive/1mePk3KzwTD81hr5mcNcs_AX3Kbg_Ha0x?usp=sharing) | [Code](https://colab.research.google.com/drive/1dgHxjvTmX6hb0bPcLp26RXSE6_n9DKj7?usp=sharing) | [Code](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | [Code](https://colab.research.google.com/drive/18yOiyX0T81mTwZqOALFSCX_tSAqju6aD?usp=sharing) | | |
| seconds | 1813 | 1571 | 842 | 718 | 393 | 132 |
| memory MB | 32853 | 19385 | 12465 | 10271 | | |
| % saved | | 40.99 | 62.06 | 68.74 | | |

### CodeLlama 34b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Code Llama 34B | OOM ❌ | 0.99x | 1.87x | 2.61x | 4.27x | 12.82x |
| code | [▶️ Code](https://colab.research.google.com/drive/1ykfz3BqrtC_AUFegCzUQjjfUNlxp6Otc?usp=sharing) | [Code](https://colab.research.google.com/drive/12ZypxQh7OC6kBXvWZI-5d05I4m-B_hoR?usp=sharing) | [Code](https://colab.research.google.com/drive/1gdHyAx8XJsz2yNV-DHvbHjR1iCef5Qmh?usp=sharing) | [Code](https://colab.research.google.com/drive/1fm7wqx9MJ0kRrwKOfmLkK1Rmw-pySahB?usp=sharing) | | |
| seconds | 1953 | 1982 | 1043 | 748 | 458 | 152 |
| memory MB | 40000 | 33217 | 27413 | 22161 | | |
| % saved | | 16.96 | 31.47 | 44.60 | | |

### 1 Tesla T4

| 1 T4 16GB | Hugging Face | Flash Attention | Unsloth Open | Unsloth Pro Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-----------------|-----------------|---------------|---------------|-------------|
| Alpaca | 1x | 1.09x | 1.69x | 1.79x | 2.93x | **8.3x** |
| code | [▶️ Code](https://colab.research.google.com/drive/1XpLIV4s8Bj5uryB-X2gqM88oRGHEGdaB?usp=sharing) | [Code](https://colab.research.google.com/drive/1LyXu6CjuymQg6ddHX8g1dpUvrMa1nn4L?usp=sharing) | [Code](https://colab.research.google.com/drive/1gsv4LpY7C32otl1rgRo5wXTk4HIitXoM?usp=sharing) | [Code](https://colab.research.google.com/drive/1VtULwRQwhEnVdNryjm27zXfdSM1tNfFK?usp=sharing) | | |
| seconds | 1599 | 1468 | 942 | 894 | 545 | 193 |
| memory MB | 7199 | 7059 | 6459 | 5443 | | |
| % saved | | 1.94 | 10.28 | 24.39 | | |

### 2 Tesla T4s via DDP

| 2 T4 DDP | Hugging Face | Flash Attention | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|----------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 0.99x | 4.95x | 4.44x | 7.28x | **20.61x** |
| code | [▶️ Code](https://www.kaggle.com/danielhanchen/hf-original-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/hf-sdpa-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) | | | |
| seconds | 9882 | 9946 | 1996 | 2227 | 1357 | 480 |
| memory MB | 9176 | 9128 | 6904 | 6782 | | |
| % saved | | 0.52 | 24.76 | 26.09 | | |
</details>

### Performance comparisons on 1 Tesla T4 GPU:
<details>
<summary>Click for Time taken for 1 epoch</summary>

One Tesla T4 on Google Colab
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m |
| Unsloth Open | 1 T4 | 13h 7m (1.8x) | 31h 47m (1.8x) | 4h 27m (1.9x) | 240h 4m (1.6x) |
| Unsloth Pro | 1 T4 | 3h 6m (7.5x) | 5h 17m (10.7x) | 1h 7m (7.7x) | 59h 53m (6.5x) |
| Unsloth Max | 1 T4 | 2h 39m (8.8x) | 4h 31m (12.5x) | 0h 58m (8.9x) | 51h 30m (7.6x) |

**Peak Memory Usage**

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 7.3GB | 5.9GB | 14.0GB | 13.3GB |
| Unsloth Open | 1 T4 | 6.8GB | 5.7GB | 7.8GB | 7.7GB |
| Unsloth Pro | 1 T4 | 6.4GB | 6.4GB | 6.4GB | 6.4GB |
| Unsloth Max | 1 T4 | 11.4GB | 12.4GB | 11.9GB | 14.4GB |
</details>

<details>
<summary>Click for Performance Comparisons on 2 Tesla T4 GPUs via DDP:</summary>

**Time taken for 1 epoch**

Two Tesla T4s on Kaggle
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * |
| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * |
| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * |

**Peak Memory Usage on a Multi GPU System (2 GPUs)**

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * |
| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * |
| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * |

\* Slim Orca uses `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark with `bsz=1` for consistency.
</details>

![](https://i.ibb.co/sJ7RhGG/image-41.png)
<br>

### Thank You to
- [HuyNguyen-hust](https://github.com/HuyNguyen-hust) for making [RoPE Embeddings 28% faster](https://github.com/unslothai/unsloth/pull/238)
- [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support
- [152334H](https://github.com/152334H) for experimental DPO support
- [atgctg](https://github.com/atgctg) for syntax highlighting
unsloth-main/unsloth-main/images/Assistant.png ADDED
unsloth-main/unsloth-main/images/Colab.png ADDED
unsloth-main/unsloth-main/images/Discord button.png ADDED
unsloth-main/unsloth-main/images/Discord.png ADDED
unsloth-main/unsloth-main/images/Free version button.png ADDED
unsloth-main/unsloth-main/images/Kaggle.png ADDED
unsloth-main/unsloth-main/images/Kofi button.png ADDED
unsloth-main/unsloth-main/images/LAION 2GPU.png ADDED
unsloth-main/unsloth-main/images/Merge.png ADDED
unsloth-main/unsloth-main/images/Run.png ADDED
unsloth-main/unsloth-main/images/Slim Orca 2GPUs.png ADDED
unsloth-main/unsloth-main/images/Terminal_Type.png ADDED
unsloth-main/unsloth-main/images/Where_Terminal.png ADDED
unsloth-main/unsloth-main/images/buy me a coffee button.png ADDED
unsloth-main/unsloth-main/images/made with unsloth.png ADDED
unsloth-main/unsloth-main/images/ollama.png ADDED
unsloth-main/unsloth-main/images/peft x trl button.png ADDED
unsloth-main/unsloth-main/images/start free finetune button.png ADDED
unsloth-main/unsloth-main/images/unsloth end.png ADDED
unsloth-main/unsloth-main/images/unsloth loading page render.png ADDED
unsloth-main/unsloth-main/images/unsloth logo black text.png ADDED
unsloth-main/unsloth-main/images/unsloth logo only.png ADDED
unsloth-main/unsloth-main/images/unsloth logo white text.png ADDED
unsloth-main/unsloth-main/images/unsloth made with love.png ADDED
unsloth-main/unsloth-main/images/unsloth new logo.png ADDED
unsloth-main/unsloth-main/pyproject.toml ADDED
@@ -0,0 +1,327 @@
1
+ [build-system]
2
+ requires = ["setuptools", "setuptools-scm"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "unsloth"
7
+ dynamic = ["version"]
8
+ description = "2-5X faster LLM finetuning"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = {file = "LICENSE"}
12
+ keywords = ["ai", "llm",]
13
+ authors = [
14
+ {email = "[email protected]"},
15
+ {name = "Unsloth AI team"},
16
+ ]
17
+ maintainers = [
18
+ {name = "Daniel Han", email = "[email protected]"},
19
+ {name = "Michael Han", email = "[email protected]"},
20
+ ]
21
+ classifiers = [
22
+ "Programming Language :: Python",
23
+ ]
24
+
25
+ [tool.setuptools.dynamic]
26
+ version = {attr = "unsloth.models._utils.__version__"}
27
+
28
+ [tool.setuptools]
29
+ include-package-data = false
30
+
31
+ [tool.setuptools.packages.find]
32
+ exclude = ["images*"]
33
+
34
+ [project.optional-dependencies]
35
+ huggingface = [
36
+ "packaging",
37
+ "tyro",
38
+ "transformers>=4.44.2",
39
+ "datasets>=2.16.0",
40
+ "sentencepiece>=0.2.0",
41
+ "tqdm",
42
+ "psutil",
43
+ "wheel>=0.42.0",
44
+ "numpy",
45
+ "accelerate>=0.34.1",
46
+ "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
47
+ "peft>=0.7.1,!=0.11.0",
48
+ "protobuf<4.0.0",
49
+ "huggingface_hub",
50
+ "hf_transfer",
51
+ ]
52
+ cu118only = [
53
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
54
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
55
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
56
+ ]
57
+ cu121only = [
58
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
59
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
60
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
61
+ ]
62
+ cu118onlytorch211 = [
63
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
64
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
65
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
66
+ ]
67
+ cu121onlytorch211 = [
68
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
69
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
70
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
71
+ ]
72
+ cu118onlytorch212 = [
73
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
74
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
75
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
76
+ ]
77
+ cu121onlytorch212 = [
78
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
79
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
80
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
81
+ ]
82
+ cu118onlytorch220 = [
83
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
84
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
85
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
86
+ ]
87
+ cu121onlytorch220 = [
88
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
89
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
90
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
91
+ ]
92
+ cu118onlytorch230 = [
93
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
94
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
95
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
96
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
97
+ ]
98
+ cu121onlytorch230 = [
99
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
100
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
101
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
102
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
103
+ ]
104
+ cu118onlytorch240 = [
105
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
106
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
107
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
108
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
109
+ ]
110
+ cu121onlytorch240 = [
111
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
112
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
113
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
114
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
115
+ ]
116
+ cu118 = [
117
+ "unsloth[huggingface]",
118
+ "bitsandbytes>=0.43.3",
119
+ "unsloth[cu118only]",
120
+ ]
121
+ cu121 = [
122
+ "unsloth[huggingface]",
123
+ "bitsandbytes>=0.43.3",
124
+ "unsloth[cu121only]",
125
+ ]
126
+ cu118-torch211 = [
127
+ "unsloth[huggingface]",
128
+ "bitsandbytes>=0.43.3",
129
+ "unsloth[cu118onlytorch211]",
130
+ ]
131
+ cu121-torch211 = [
132
+ "unsloth[huggingface]",
133
+ "bitsandbytes>=0.43.3",
134
+ "unsloth[cu121onlytorch211]",
135
+ ]
136
+ cu118-torch212 = [
137
+ "unsloth[huggingface]",
138
+ "bitsandbytes>=0.43.3",
139
+ "unsloth[cu118onlytorch212]",
140
+ ]
141
+ cu121-torch212 = [
142
+ "unsloth[huggingface]",
143
+ "bitsandbytes>=0.43.3",
144
+ "unsloth[cu121onlytorch212]",
145
+ ]
146
+ cu118-torch220 = [
147
+ "unsloth[huggingface]",
148
+ "bitsandbytes>=0.43.3",
149
+ "unsloth[cu118onlytorch220]",
150
+ ]
151
+ cu121-torch220 = [
152
+ "unsloth[huggingface]",
153
+ "bitsandbytes>=0.43.3",
154
+ "unsloth[cu121onlytorch220]",
155
+ ]
156
+ cu118-torch230 = [
157
+ "unsloth[huggingface]",
158
+ "bitsandbytes>=0.43.3",
159
+ "unsloth[cu118onlytorch230]",
160
+ ]
161
+ cu121-torch230 = [
162
+ "unsloth[huggingface]",
163
+ "bitsandbytes>=0.43.3",
164
+ "unsloth[cu121onlytorch230]",
165
+ ]
166
+ cu118-torch240 = [
167
+ "unsloth[huggingface]",
168
+ "bitsandbytes>=0.43.3",
169
+ "unsloth[cu118onlytorch240]",
170
+ ]
171
+ cu121-torch240 = [
172
+ "unsloth[huggingface]",
173
+ "bitsandbytes>=0.43.3",
174
+ "unsloth[cu121onlytorch240]",
175
+ ]
176
+ kaggle = [
177
+ "unsloth[huggingface]",
178
+ ]
179
+ kaggle-new = [
180
+ "unsloth[huggingface]",
181
+ "bitsandbytes>=0.43.3",
182
+ ]
183
+ conda = [
184
+ "unsloth[huggingface]",
185
+ ]
186
+ colab-torch211 = [
187
+ "unsloth[huggingface]",
188
+ "bitsandbytes>=0.43.3",
189
+ "unsloth[cu121onlytorch211]",
190
+ ]
191
+ colab-ampere-torch211 = [
192
+ "unsloth[huggingface]",
193
+ "bitsandbytes>=0.43.3",
194
+ "unsloth[cu121onlytorch211]",
195
+ "packaging",
196
+ "ninja",
197
+ "flash-attn>=2.6.3",
198
+ ]
199
+ colab-torch220 = [
200
+ "unsloth[huggingface]",
201
+ "bitsandbytes>=0.43.3",
202
+ "unsloth[cu121onlytorch220]",
203
+ ]
204
+ colab-ampere-torch220 = [
205
+ "unsloth[huggingface]",
206
+ "bitsandbytes>=0.43.3",
207
+ "unsloth[cu121onlytorch220]",
208
+ "packaging",
209
+ "ninja",
210
+ "flash-attn>=2.6.3",
211
+ ]
212
+ colab-new = [
213
+ "packaging",
214
+ "tyro",
215
+ "transformers>=4.44.2",
216
+ "datasets>=2.16.0",
217
+ "sentencepiece>=0.2.0",
218
+ "tqdm",
219
+ "psutil",
220
+ "wheel>=0.42.0",
221
+ "numpy",
222
+ "protobuf<4.0.0",
223
+ "huggingface_hub",
224
+ "hf_transfer",
225
+ ]
226
+ colab-no-deps = [
227
+ "accelerate>=0.34.1",
228
+ "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
229
+ "peft>=0.7.1",
230
+ "xformers<0.0.27",
231
+ "bitsandbytes>=0.43.3",
232
+ "protobuf<4.0.0",
233
+ ]
234
+ colab = [
235
+ "unsloth[cu121]",
236
+ ]
237
+ colab-ampere = [
238
+ "unsloth[colab-ampere-torch220]",
239
+ "packaging",
240
+ "ninja",
241
+ "flash-attn>=2.6.3",
242
+ ]
243
+ cu118-ampere = [
244
+ "unsloth[huggingface]",
245
+ "bitsandbytes>=0.43.3",
246
+ "unsloth[cu118only]",
247
+ "packaging",
248
+ "ninja",
249
+ "flash-attn>=2.6.3",
250
+ ]
251
+ cu121-ampere = [
252
+ "unsloth[huggingface]",
253
+ "bitsandbytes>=0.43.3",
254
+ "unsloth[cu121only]",
255
+ "packaging",
256
+ "ninja",
257
+ "flash-attn>=2.6.3",
258
+ ]
259
+ cu118-ampere-torch211 = [
260
+ "unsloth[huggingface]",
261
+ "bitsandbytes>=0.43.3",
262
+ "unsloth[cu118onlytorch211]",
263
+ "packaging",
264
+ "ninja",
265
+ "flash-attn>=2.6.3",
266
+ ]
267
+ cu121-ampere-torch211 = [
268
+ "unsloth[huggingface]",
269
+ "bitsandbytes>=0.43.3",
270
+ "unsloth[cu121onlytorch211]",
271
+ "packaging",
272
+ "ninja",
273
+ "flash-attn>=2.6.3",
274
+ ]
275
+ cu118-ampere-torch220 = [
276
+ "unsloth[huggingface]",
277
+ "bitsandbytes>=0.43.3",
278
+ "unsloth[cu118onlytorch220]",
279
+ "packaging",
280
+ "ninja",
281
+ "flash-attn>=2.6.3",
282
+ ]
283
+ cu121-ampere-torch220 = [
284
+ "unsloth[huggingface]",
285
+ "bitsandbytes>=0.43.3",
286
+ "unsloth[cu121onlytorch220]",
287
+ "packaging",
288
+ "ninja",
289
+ "flash-attn>=2.6.3",
290
+ ]
291
+ cu118-ampere-torch230 = [
292
+ "unsloth[huggingface]",
293
+ "bitsandbytes>=0.43.3",
294
+ "unsloth[cu118onlytorch230]",
295
+ "packaging",
296
+ "ninja",
297
+ "flash-attn>=2.6.3",
298
+ ]
299
+ cu121-ampere-torch230 = [
300
+ "unsloth[huggingface]",
301
+ "bitsandbytes>=0.43.3",
302
+ "unsloth[cu121onlytorch230]",
303
+ "packaging",
304
+ "ninja",
305
+ "flash-attn>=2.6.3",
306
+ ]
307
+ cu118-ampere-torch240 = [
308
+ "unsloth[huggingface]",
309
+ "bitsandbytes>=0.43.3",
310
+ "unsloth[cu118onlytorch240]",
311
+ "packaging",
312
+ "ninja",
313
+ "flash-attn>=2.6.3",
314
+ ]
315
+ cu121-ampere-torch240 = [
316
+ "unsloth[huggingface]",
317
+ "bitsandbytes>=0.43.3",
318
+ "unsloth[cu121onlytorch240]",
319
+ "packaging",
320
+ "ninja",
321
+ "flash-attn>=2.6.3",
322
+ ]
323
+
324
+ [project.urls]
325
+ homepage = "http://www.unsloth.ai"
326
+ documentation = "https://github.com/unslothai/unsloth"
327
+ repository = "https://github.com/unslothai/unsloth"
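
The optional extras above only pin prebuilt xformers wheels for one CUDA / PyTorch / Python combination each. A minimal sketch, assuming a standard torch install, to check which combination is actually present before picking an extra (the version strings in the comments are illustrative):

```python
# Print the local Python / torch / CUDA / xformers versions and compare them
# against the extras table above (e.g. Python 3.11 + CUDA 12.1 + torch 2.4 -> cu121-torch240).
import sys
import torch

print("python  :", sys.version.split()[0])    # e.g. 3.11.x -> cp311 wheels
print("torch   :", torch.__version__)         # e.g. 2.4.0  -> *torch240 extras
print("cuda    :", torch.version.cuda)        # e.g. 12.1   -> cu121 extras

try:
    import xformers
    print("xformers:", xformers.__version__)  # e.g. 0.0.27.post2 for torch 2.4
except ImportError:
    print("xformers is not installed yet")
```
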
unsloth-main/unsloth-main/unsloth-cli.py ADDED
@@ -0,0 +1,221 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ 🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
5
+
6
+ This script is designed as a starting point for fine-tuning your models using unsloth.
7
+ It includes configurable options for model loading, PEFT parameters, training arguments,
8
+ and model saving/pushing functionalities.
9
+
10
+ You will likely want to customize this script to suit your specific use case
11
+ and requirements.
12
+
13
+ Here are a few suggestions for customization:
14
+ - Modify the dataset loading and preprocessing steps to match your data.
15
+ - Customize the model saving and pushing configurations.
16
+
17
+ Usage: (most options have valid default values; this extended example is for demonstration purposes)
18
+ python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
19
+ --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
20
+ --random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
21
+ --warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
22
+ --weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
23
+ --report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
24
+ --push_model --hub_path "hf/model" --hub_token "your_hf_token"
25
+
26
+ To see a full list of configurable options, use:
27
+ python unsloth-cli.py --help
28
+
29
+ Happy fine-tuning!
30
+ """
31
+
32
+ import argparse
33
+
34
+ def run(args):
35
+ import torch
36
+ from unsloth import FastLanguageModel
37
+ from datasets import load_dataset
38
+ from trl import SFTTrainer
39
+ from transformers import TrainingArguments
40
+ from unsloth import is_bfloat16_supported
41
+ import logging
42
+ logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
43
+
44
+ # Load model and tokenizer
45
+ model, tokenizer = FastLanguageModel.from_pretrained(
46
+ model_name=args.model_name,
47
+ max_seq_length=args.max_seq_length,
48
+ dtype=args.dtype,
49
+ load_in_4bit=args.load_in_4bit,
50
+ )
51
+
52
+ # Configure PEFT model
53
+ model = FastLanguageModel.get_peft_model(
54
+ model,
55
+ r=args.r,
56
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
57
+ "gate_proj", "up_proj", "down_proj"],
58
+ lora_alpha=args.lora_alpha,
59
+ lora_dropout=args.lora_dropout,
60
+ bias=args.bias,
61
+ use_gradient_checkpointing=args.use_gradient_checkpointing,
62
+ random_state=args.random_state,
63
+ use_rslora=args.use_rslora,
64
+ loftq_config=args.loftq_config,
65
+ )
66
+
67
+ alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
68
+
69
+ ### Instruction:
70
+ {}
71
+
72
+ ### Input:
73
+ {}
74
+
75
+ ### Response:
76
+ {}"""
77
+
78
+ EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
79
+ def formatting_prompts_func(examples):
80
+ instructions = examples["instruction"]
81
+ inputs = examples["input"]
82
+ outputs = examples["output"]
83
+ texts = []
84
+ for instruction, input, output in zip(instructions, inputs, outputs):
85
+ text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
86
+ texts.append(text)
87
+ return {"text": texts}
88
+
89
+ # Load and format dataset
90
+ dataset = load_dataset(args.dataset, split="train")
91
+ dataset = dataset.map(formatting_prompts_func, batched=True)
92
+ print("Data is formatted and ready!")
93
+
94
+ # Configure training arguments
95
+ training_args = TrainingArguments(
96
+ per_device_train_batch_size=args.per_device_train_batch_size,
97
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
98
+ warmup_steps=args.warmup_steps,
99
+ max_steps=args.max_steps,
100
+ learning_rate=args.learning_rate,
101
+ fp16=not is_bfloat16_supported(),
102
+ bf16=is_bfloat16_supported(),
103
+ logging_steps=args.logging_steps,
104
+ optim=args.optim,
105
+ weight_decay=args.weight_decay,
106
+ lr_scheduler_type=args.lr_scheduler_type,
107
+ seed=args.seed,
108
+ output_dir=args.output_dir,
109
+ report_to=args.report_to,
110
+ )
111
+
112
+ # Initialize trainer
113
+ trainer = SFTTrainer(
114
+ model=model,
115
+ tokenizer=tokenizer,
116
+ train_dataset=dataset,
117
+ dataset_text_field="text",
118
+ max_seq_length=args.max_seq_length,
119
+ dataset_num_proc=2,
120
+ packing=False,
121
+ args=training_args,
122
+ )
123
+
124
+ # Train model
125
+ trainer_stats = trainer.train()
126
+
127
+ # Save model
128
+ if args.save_model:
129
+ # if args.quantization is a list, we will save the model for each quantization method
130
+ if args.save_gguf:
131
+ if isinstance(args.quantization, list):
132
+ for quantization_method in args.quantization:
133
+ print(f"Saving model with quantization method: {quantization_method}")
134
+ model.save_pretrained_gguf(
135
+ args.save_path,
136
+ tokenizer,
137
+ quantization_method=quantization_method,
138
+ )
139
+ if args.push_model:
140
+ model.push_to_hub_gguf(
141
+ hub_path=args.hub_path,
142
+ hub_token=args.hub_token,
143
+ quantization_method=quantization_method,
144
+ )
145
+ else:
146
+ print(f"Saving model with quantization method: {args.quantization}")
147
+ model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
148
+ if args.push_model:
149
+ model.push_to_hub_gguf(
150
+ hub_path=args.hub_path,
151
+ hub_token=args.hub_token,
152
+ quantization_method=args.quantization,
153
+ )
154
+ else:
155
+ model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
156
+ if args.push_model:
157
+ model.push_to_hub_merged(args.hub_path, tokenizer, args.save_method, token=args.hub_token)
158
+ else:
159
+ print("Warning: The model is not saved!")
160
+
161
+
162
+ if __name__ == "__main__":
163
+
164
+ # Define argument parser
165
+ parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")
166
+
167
+ model_group = parser.add_argument_group("🤖 Model Options")
168
+ model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
169
+ model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
170
+ model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
171
+ model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
172
+ model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
173
+
174
+ lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
175
+ lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
176
+ lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
177
+ lora_group.add_argument('--lora_dropout', type=float, default=0, help="LoRA dropout rate, default is 0.0 which is optimized.")
178
+ lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
179
+ lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
180
+ lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
181
+ lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
182
+ lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
183
+
184
+
185
+ training_group = parser.add_argument_group("🎓 Training Options")
186
+ training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
187
+ training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
188
+ training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
189
+ training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
190
+ training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
191
+ training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
192
+ training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
193
+ training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
194
+ training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
195
+
196
+
197
+ # Report/Logging arguments
198
+ report_group = parser.add_argument_group("📊 Report Options")
199
+ report_group.add_argument('--report_to', type=str, default="tensorboard",
200
+ choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
201
+ help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
202
+ report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
203
+
204
+ # Saving and pushing arguments
205
+ save_group = parser.add_argument_group('💾 Save Model Options')
206
+ save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
207
+ save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
208
+ save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
209
+ save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
210
+ save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
211
+ save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
212
+ help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")
213
+
214
+ push_group = parser.add_argument_group('🚀 Push Model Options')
215
+ push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
216
+ push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
217
+ push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
218
+ push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
219
+
220
+ args = parser.parse_args()
221
+ run(args)
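
The CLI above is a thin wrapper around the usual Unsloth fine-tuning flow. The following condensed sketch mirrors the same calls with illustrative hyperparameters; the tiny dataset slice and simplified prompt formatting are only for a smoke test, not a training recipe:

```python
# Condensed sketch of the flow unsloth-cli.py wraps (hyperparameters are illustrative).
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b", max_seq_length = 2048, load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(
    model, r = 16, lora_alpha = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
)

# Tiny slice and simplified formatting, purely to verify the pipeline runs.
dataset = load_dataset("yahma/alpaca-cleaned", split = "train[:1%]")
dataset = dataset.map(lambda ex: {
    "text": ex["instruction"] + "\n" + ex["input"] + "\n" + ex["output"] + tokenizer.eos_token
})

trainer = SFTTrainer(
    model = model, tokenizer = tokenizer, train_dataset = dataset,
    dataset_text_field = "text", max_seq_length = 2048,
    args = TrainingArguments(
        per_device_train_batch_size = 2, gradient_accumulation_steps = 4,
        warmup_steps = 5, max_steps = 10, learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(), bf16 = is_bfloat16_supported(),
        optim = "adamw_8bit", output_dir = "outputs", report_to = "none",
    ),
)
trainer.train()
```
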
unsloth-main/unsloth-main/unsloth/__init__.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings, importlib, sys
16
+ from packaging.version import Version
17
+ import os, re, subprocess, inspect
18
+ import numpy as np
19
+
20
+ # # Define a list of modules to check
21
+ # MODULES_TO_CHECK = ["bitsandbytes"]
22
+
23
+ # # Check if any of the modules in the list have been imported
24
+ # for module in MODULES_TO_CHECK:
25
+ # if module in sys.modules:
26
+ # raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
27
+ # pass
28
+ # pass
29
+
30
+ # Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
31
+ # enabling it will require much more work, so we have to prioritize. Please understand!
32
+ # We do have a beta version, which you can contact us about!
33
+ # Thank you for your understanding and we appreciate it immensely!
34
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
35
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
36
+ devices = os.environ["CUDA_VISIBLE_DEVICES"]
37
+ # Check if there are multiple cuda devices set in env
38
+ if not devices.isdigit():
39
+ first_id = devices.split(",")[0]
40
+ warnings.warn(
41
+ f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
42
+ "Unsloth currently does not support multi GPU setups - but we are working on it!\n"\
43
+ "Multiple CUDA devices detected but we require a single device.\n"\
44
+ f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
45
+ )
46
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
47
+ else:
48
+ # warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
49
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
50
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
51
+ pass
52
+
53
+ # Reduce VRAM usage by reducing fragmentation
54
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
55
+
56
+ try:
57
+ import torch
58
+ except:
59
+ raise ImportError("Pytorch is not installed. Go to https://pytorch.org/.\n"\
60
+ "We have some installation instructions on our Github page.")
61
+ pass
62
+
63
+ # Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions)
64
+ keynames = "\n" + "\n".join(os.environ.keys())
65
+ if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames:
66
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
67
+ pass
68
+
69
+ # We support Pytorch 2
70
+ # Fixes https://github.com/unslothai/unsloth/issues/38
71
+ torch_version = torch.__version__.split(".")
72
+ major_torch, minor_torch = torch_version[0], torch_version[1]
73
+ major_torch, minor_torch = int(major_torch), int(minor_torch)
74
+ if (major_torch < 2):
75
+ raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
76
+ "We have some installation instructions on our Github page.")
77
+ elif (major_torch == 2) and (minor_torch < 2):
78
+ # Disable expandable_segments
79
+ del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
80
+ pass
81
+
82
+ # Torch 2.4 has including_emulation
83
+ major_version, minor_version = torch.cuda.get_device_capability()
84
+ SUPPORTS_BFLOAT16 = (major_version >= 8)
85
+
86
+ old_is_bf16_supported = torch.cuda.is_bf16_supported
87
+ if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
88
+ def is_bf16_supported(including_emulation = False):
89
+ return old_is_bf16_supported(including_emulation)
90
+ torch.cuda.is_bf16_supported = is_bf16_supported
91
+ else:
92
+ def is_bf16_supported(): return SUPPORTS_BFLOAT16
93
+ torch.cuda.is_bf16_supported = is_bf16_supported
94
+ pass
95
+
96
+ # Try loading bitsandbytes and triton
97
+ import bitsandbytes as bnb
98
+
99
+ if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
100
+
101
+ import triton
102
+ libcuda_dirs = lambda: None
103
+ if Version(triton.__version__) >= Version("3.0.0"):
104
+ try: from triton.backends.nvidia.driver import libcuda_dirs
105
+ except: pass
106
+ else: from triton.common.build import libcuda_dirs
107
+
108
+ try:
109
+ cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
110
+ libcuda_dirs()
111
+ except:
112
+ warnings.warn(
113
+ "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
114
+ )
115
+
116
+ if os.path.exists("/usr/lib64-nvidia"):
117
+ os.system("ldconfig /usr/lib64-nvidia")
118
+ elif os.path.exists("/usr/local"):
119
+ # Sometimes bitsandbytes cannot be linked properly in Runpod for example
120
+ possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
121
+ find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
122
+ possible_cudas = [find_cuda.search(x) for x in possible_cudas]
123
+ possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
124
+
125
+ # Try linking cuda folder, or everything in local
126
+ if len(possible_cudas) == 0:
127
+ os.system(f"ldconfig /usr/local/")
128
+ else:
129
+ find_number = re.compile(r"([\d\.]{2,})")
130
+ latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
131
+ latest_cuda = possible_cudas[latest_cuda]
132
+ os.system(f"ldconfig /usr/local/{latest_cuda}")
133
+ pass
134
+
135
+ importlib.reload(bnb)
136
+ importlib.reload(triton)
137
+ try:
138
+ libcuda_dirs = lambda: None
139
+ if Version(triton.__version__) >= Version("3.0.0"):
140
+ try: from triton.backends.nvidia.driver import libcuda_dirs
141
+ except: pass
142
+ else: from triton.common.build import libcuda_dirs
143
+ cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
144
+ libcuda_dirs()
145
+ except:
146
+ warnings.warn(
147
+ "Unsloth: CUDA is not linked properly.\n"\
148
+ "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
149
+ "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
150
+ "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
151
+ "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
152
+ "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
153
+ )
154
+ pass
155
+ pass
156
+
157
+ from .models import *
158
+ from .save import *
159
+ from .chat_templates import *
160
+ from .tokenizer_utils import *
161
+ from .trainer import *
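
Note that this __init__ narrows CUDA_VISIBLE_DEVICES down to a single device on import. A minimal sketch for pinning a specific GPU on a multi-GPU machine (device id "1" is illustrative), which must happen before the import:

```python
# Select the GPU before importing Unsloth, since the import enforces a single visible device.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # illustrative: expose only the second GPU

import unsloth  # from here on, only that GPU is visible (as cuda:0)
```
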
unsloth-main/unsloth-main/unsloth/_auto_install.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ try: import torch
16
+ except: raise ImportError('Install torch via `pip install torch`')
17
+ from packaging.version import Version as V
18
+ v = V(torch.__version__)
19
+ cuda = str(torch.version.cuda)
20
+ is_ampere = torch.cuda.get_device_capability()[0] >= 8
21
+ if cuda != "12.1" and cuda != "11.8": raise RuntimeError(f"CUDA = {cuda} not supported!")
22
+ if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
23
+ elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
24
+ elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
25
+ elif v < V('2.3.0'): x = 'cu{}{}-torch220'
26
+ elif v < V('2.4.0'): x = 'cu{}{}-torch230'
27
+ elif v < V('2.5.0'): x = 'cu{}{}-torch240'
28
+ else: raise RuntimeError(f"Torch = {v} too new!")
29
+ x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
30
+ print(f'pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
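
This helper prints the pip command for the extra that matches the local torch / CUDA setup. A small sketch for running it from a source checkout (the relative path is taken from this upload's layout and may differ in your checkout):

```python
# Execute the version-detection helper above and echo its recommendation.
import runpy

runpy.run_path("unsloth-main/unsloth-main/unsloth/_auto_install.py")
# Prints something like:
# pip install "unsloth[cu121-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
```
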
unsloth-main/unsloth-main/unsloth/chat_templates.py ADDED
@@ -0,0 +1,2210 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ __all__ = [
16
+ "get_chat_template",
17
+ "test_chat_templates",
18
+ "test_hf_gguf_equivalence",
19
+ "remove_special_tokens",
20
+
21
+ "to_sharegpt",
22
+ "standardize_sharegpt",
23
+ "apply_chat_template",
24
+ "train_on_responses_only",
25
+
26
+ "test_construct_chat_template",
27
+ ]
28
+
29
+ from transformers import StoppingCriteria, StoppingCriteriaList
30
+ from torch import LongTensor, FloatTensor
31
+ from transformers.models.llama.modeling_llama import logger
32
+ from .save import patch_saving_functions
33
+ import os
34
+ import shutil
35
+ from .tokenizer_utils import *
36
+ from .models._utils import patch_tokenizer
37
+ import re
38
+
39
+ CHAT_TEMPLATES = {}
40
+
41
+ # =========================================== Unsloth
42
+ # Unsloth's efficient template borrows from Zephyr
43
+ unsloth_template = \
44
+ "{{ bos_token }}"\
45
+ "{% if messages[0]['role'] == 'system' %}"\
46
+ "{{ messages[0]['content'] + '\n' }}"\
47
+ "{% set loop_messages = messages[1:] %}"\
48
+ "{% else %}"\
49
+ "{{ 'You are a helpful assistant to the user\n' }}"\
50
+ "{% set loop_messages = messages %}"\
51
+ "{% endif %}"\
52
+ "{% for message in loop_messages %}"\
53
+ "{% if message['role'] == 'user' %}"\
54
+ "{{ '>>> User: ' + message['content'] + '\n' }}"\
55
+ "{% elif message['role'] == 'assistant' %}"\
56
+ "{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
57
+ "{% else %}"\
58
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
59
+ "{% endif %}"\
60
+ "{% endfor %}"\
61
+ "{% if add_generation_prompt %}"\
62
+ "{{ '>>> Assistant: ' }}"\
63
+ "{% endif %}"
64
+ pass
65
+
66
+ unsloth_ollama = \
67
+ '''
68
+ FROM {__FILE_LOCATION__}
69
+ TEMPLATE """{{ if .System }}{{ .System }}
70
+ {{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
71
+ {{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
72
+ """
73
+ PARAMETER stop "{__EOS_TOKEN__}"
74
+ PARAMETER temperature 1.5
75
+ PARAMETER min_p 0.1
76
+ SYSTEM """You are a helpful assistant to the user"""
77
+ '''
78
+
79
+ unsloth_eos_token = "eos_token"
80
+ CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
81
+ pass
82
+
83
+ # =========================================== Zephyr
84
+ # Zephyr has no BOS!
85
+ zephyr_template = \
86
+ "{% for message in messages %}"\
87
+ "{% if message['role'] == 'user' %}"\
88
+ "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"\
89
+ "{% elif message['role'] == 'assistant' %}"\
90
+ "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
91
+ "{% else %}"\
92
+ "{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"\
93
+ "{% endif %}"\
94
+ "{% endfor %}"\
95
+ "{% if add_generation_prompt %}"\
96
+ "{{ '<|assistant|>\n' }}"\
97
+ "{% endif %}"
98
+ pass
99
+
100
+ zephyr_ollama = \
101
+ '''
102
+ FROM {__FILE_LOCATION__}
103
+ TEMPLATE """{{ if .System }}<|system|>
104
+ {{ .System }}{__EOS_TOKEN__}
105
+ {{ end }}{{ if .Prompt }}<|user|>
106
+ {{ .Prompt }}{__EOS_TOKEN__}
107
+ {{ end }}<|assistant|>
108
+ {{ .Response }}{__EOS_TOKEN__}
109
+ """
110
+ PARAMETER stop "{__EOS_TOKEN__}"
111
+ PARAMETER temperature 1.5
112
+ PARAMETER min_p 0.1
113
+ '''
114
+
115
+ zephyr_eos_token = "eos_token"
116
+ CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
117
+ pass
118
+
119
+ # =========================================== ChatML
120
+ # ChatML has no BOS and no EOS! Rather, <|im_start|> and <|im_end|> act as BOS / EOS.
121
+ chatml_template = \
122
+ "{% for message in messages %}"\
123
+ "{% if message['role'] == 'user' %}"\
124
+ "{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"\
125
+ "{% elif message['role'] == 'assistant' %}"\
126
+ "{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"\
127
+ "{% else %}"\
128
+ "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"\
129
+ "{% endif %}"\
130
+ "{% endfor %}"\
131
+ "{% if add_generation_prompt %}"\
132
+ "{{ '<|im_start|>assistant\n' }}"\
133
+ "{% endif %}"
134
+ pass
135
+
136
+ chatml_ollama = \
137
+ '''
138
+ FROM {__FILE_LOCATION__}
139
+ TEMPLATE """{{ if .System }}<|im_start|>system
140
+ {{ .System }}<|im_end|>
141
+ {{ end }}{{ if .Prompt }}<|im_start|>user
142
+ {{ .Prompt }}<|im_end|>
143
+ {{ end }}<|im_start|>assistant
144
+ {{ .Response }}<|im_end|>
145
+ """
146
+ PARAMETER stop "<|im_start|>"
147
+ PARAMETER stop "<|im_end|>"
148
+ PARAMETER temperature 1.5
149
+ PARAMETER min_p 0.1
150
+ '''
151
+
152
+ chatml_eos_token = "<|im_end|>"
153
+ CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
154
+ pass
155
+
156
+ # =========================================== Mistral-1
157
+ # Mistral Instruct doesn't allow system prompts, so we prepend the system message to the first user message.
158
+ mistral_template = \
159
+ "{{ bos_token }}"\
160
+ "{% if messages[0]['role'] == 'system' %}"\
161
+ "{% if messages[1]['role'] == 'user' %}"\
162
+ "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
163
+ "{% set loop_messages = messages[2:] %}"\
164
+ "{% else %}"\
165
+ "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
166
+ "{% set loop_messages = messages[1:] %}"\
167
+ "{% endif %}"\
168
+ "{% else %}"\
169
+ "{% set loop_messages = messages %}"\
170
+ "{% endif %}"\
171
+ "{% for message in loop_messages %}"\
172
+ "{% if message['role'] == 'user' %}"\
173
+ "{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
174
+ "{% elif message['role'] == 'assistant' %}"\
175
+ "{{ message['content'] + eos_token }}"\
176
+ "{% else %}"\
177
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
178
+ "{% endif %}"\
179
+ "{% endfor %}"
180
+ pass
181
+
182
+ # Ollama from https://www.ollama.com/library/mistral
183
+ mistral_ollama = \
184
+ '''
185
+ FROM {__FILE_LOCATION__}
186
+ TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
187
+ PARAMETER stop "{__EOS_TOKEN__}"
188
+ PARAMETER temperature 1.5
189
+ PARAMETER min_p 0.1
190
+ '''
191
+
192
+ mistral_eos_token = "eos_token"
193
+ CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
194
+ pass
195
+
196
+ # =========================================== Llama-2
197
+ # Adds BOS to every convo! And weird <<SYS>> system messages.
198
+ llama_template = \
199
+ "{% if messages[0]['role'] == 'system' %}"\
200
+ "{% if messages[1]['role'] == 'user' %}"\
201
+ "{{ bos_token + '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
202
+ "{% set loop_messages = messages[2:] %}"\
203
+ "{% else %}"\
204
+ "{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
205
+ "{% set loop_messages = messages[1:] %}"\
206
+ "{% endif %}"\
207
+ "{% else %}"\
208
+ "{% set loop_messages = messages %}"\
209
+ "{% endif %}"\
210
+ "{% for message in loop_messages %}"\
211
+ "{% if message['role'] == 'user' %}"\
212
+ "{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
213
+ "{% elif message['role'] == 'assistant' %}"\
214
+ "{{ ' ' + message['content'].strip() + ' ' + eos_token }}"\
215
+ "{% else %}"\
216
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
217
+ "{% endif %}"\
218
+ "{% endfor %}"
219
+ pass
220
+
221
+ # Ollama from https://www.ollama.com/library/llama3
222
+ llama_ollama = \
223
+ '''
224
+ FROM {__FILE_LOCATION__}
225
+ TEMPLATE """[INST] <<SYS>>{{ .System }}<</SYS>>
226
+
227
+ {{ .Prompt }} [/INST]"""
228
+ PARAMETER stop "{__EOS_TOKEN__}"
229
+ PARAMETER temperature 1.5
230
+ PARAMETER min_p 0.1
231
+ '''
232
+
233
+ llama_eos_token = "eos_token"
234
+ CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
235
+ pass
236
+
237
+ # =========================================== Vicuna
238
+ # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
239
+ vicuna_template = \
240
+ "{{ bos_token }}"\
241
+ "{% if messages[0]['role'] == 'system' %}"\
242
+ "{{ messages[0]['content'] + ' ' }}"\
243
+ "{% set loop_messages = messages[1:] %}"\
244
+ "{% else %}"\
245
+ "{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\
246
+ "{% set loop_messages = messages %}"\
247
+ "{% endif %}"\
248
+ "{% for message in loop_messages %}"\
249
+ "{% if message['role'] == 'user' %}"\
250
+ "{{ 'USER: ' + message['content'] + ' ' }}"\
251
+ "{% elif message['role'] == 'assistant' %}"\
252
+ "{{ 'ASSISTANT: ' + message['content'] + eos_token }}"\
253
+ "{% else %}"\
254
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
255
+ "{% endif %}"\
256
+ "{% endfor %}"\
257
+ "{% if add_generation_prompt %}"\
258
+ "{{ 'ASSISTANT:' }}"\
259
+ "{% endif %}"
260
+ pass
261
+
262
+ # Ollama from https://www.ollama.com/library/vicuna
263
+ vicuna_ollama = \
264
+ '''
265
+ FROM {__FILE_LOCATION__}
266
+ TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
267
+ PARAMETER stop "{__EOS_TOKEN__}"
268
+ PARAMETER temperature 1.5
269
+ PARAMETER min_p 0.1
270
+ '''
271
+
272
+ vicuna_eos_token = "eos_token"
273
+ CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
274
+ pass
275
+
276
+ # =========================================== Vicuna Old
277
+ # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
278
+ vicuna_old_template = \
279
+ "{{ bos_token }}"\
280
+ "{% if messages[0]['role'] == 'system' %}"\
281
+ "{{ messages[0]['content'] + '\n' }}"\
282
+ "{% set loop_messages = messages[1:] %}"\
283
+ "{% else %}"\
284
+ "{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\
285
+ "{% set loop_messages = messages %}"\
286
+ "{% endif %}"\
287
+ "{% for message in loop_messages %}"\
288
+ "{% if message['role'] == 'user' %}"\
289
+ "{{ '### Human: ' + message['content'] + '\n' }}"\
290
+ "{% elif message['role'] == 'assistant' %}"\
291
+ "{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"\
292
+ "{% else %}"\
293
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
294
+ "{% endif %}"\
295
+ "{% endfor %}"\
296
+ "{% if add_generation_prompt %}"\
297
+ "{{ '### Assistant:' }}"\
298
+ "{% endif %}"
299
+ pass
300
+
301
+ vicuna_old_ollama = \
302
+ '''
303
+ FROM {__FILE_LOCATION__}
304
+ TEMPLATE """{{ if .System }}{{ .System }}
305
+ {{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
306
+ {{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
307
+ """
308
+ PARAMETER stop "{__EOS_TOKEN__}"
309
+ PARAMETER temperature 1.5
310
+ PARAMETER min_p 0.1
311
+ SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
312
+ '''
313
+
314
+ vicuna_old_eos_token = "eos_token"
315
+ CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
316
+ pass
317
+
318
+ # =========================================== Alpaca multi turn
319
+ # https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
320
+ alpaca_template = \
321
+ "{{ bos_token }}"\
322
+ "{% if messages[0]['role'] == 'system' %}"\
323
+ "{{ messages[0]['content'] + '\n\n' }}"\
324
+ "{% set loop_messages = messages[1:] %}"\
325
+ "{% else %}"\
326
+ "{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\
327
+ "{% set loop_messages = messages %}"\
328
+ "{% endif %}"\
329
+ "{% for message in loop_messages %}"\
330
+ "{% if message['role'] == 'user' %}"\
331
+ "{{ '### Instruction:\n' + message['content'] + '\n\n' }}"\
332
+ "{% elif message['role'] == 'assistant' %}"\
333
+ "{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"\
334
+ "{% else %}"\
335
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
336
+ "{% endif %}"\
337
+ "{% endfor %}"\
338
+ "{% if add_generation_prompt %}"\
339
+ "{{ '### Response:\n' }}"\
340
+ "{% endif %}"
341
+ pass
342
+
343
+ alpaca_ollama = \
344
+ '''
345
+ FROM {__FILE_LOCATION__}
346
+ TEMPLATE """{{ if .System }}{{ .System }}
347
+
348
+ {{ end }}{{ if .Prompt }}### Instruction:
349
+ {{ .Prompt }}{{ end }}
350
+
351
+ ### Response:
352
+ {{ .Response }}{__EOS_TOKEN__}
353
+
354
+ """
355
+ PARAMETER stop "{__EOS_TOKEN__}"
356
+ PARAMETER temperature 1.5
357
+ PARAMETER min_p 0.1
358
+ SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
359
+ '''
360
+
361
+ alpaca_eos_token = "eos_token"
362
+ CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
363
+ pass
364
+
365
+ # =========================================== Gemma
366
+ # https://huggingface.co/google/gemma-7b-it
367
+ # Notice we must use |trim for lstrip and rstrip. <start_of_turn> maps to 106.
368
+ # <end_of_turn> maps to 107. user and model are normal 1 word tokens.
369
+ gemma_template = \
370
+ "{{ bos_token }}"\
371
+ "{% if messages[0]['role'] == 'system' %}"\
372
+ "{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
373
+ "{% set loop_messages = messages[2:] %}"\
374
+ "{% endif %}"\
375
+ "{% for message in messages %}"\
376
+ "{% if message['role'] == 'user' %}"\
377
+ "{{'<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>\n'}}"\
378
+ "{% elif message['role'] == 'assistant' %}"\
379
+ "{{'<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
380
+ "{% else %}"\
381
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
382
+ "{% endif %}"\
383
+ "{% endfor %}"\
384
+ "{% if add_generation_prompt %}"\
385
+ "{{ '<start_of_turn>model\n' }}"\
386
+ "{% endif %}"
387
+ pass
388
+
389
+ # Ollama from https://www.ollama.com/library/gemma
390
+ gemma_ollama = \
391
+ '''
392
+ FROM {__FILE_LOCATION__}
393
+ TEMPLATE """<start_of_turn>user
394
+ {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
395
+ <start_of_turn>model
396
+ {{ .Response }}<end_of_turn>
397
+ """
398
+ PARAMETER repeat_penalty 1
399
+ PARAMETER stop "<start_of_turn>"
400
+ PARAMETER stop "<end_of_turn>"
401
+ PARAMETER penalize_newline false
402
+ PARAMETER temperature 1.5
403
+ PARAMETER min_p 0.1
404
+ '''
405
+
406
+ gemma_eos_token = "<end_of_turn>"
407
+ CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
408
+ pass
409
+
410
+ # =========================================== Gemma with ChatML instead
411
+ # We find using <eos> is still more appropriate!
412
+ gemma_chatml_template = "{{ bos_token }}" + chatml_template
413
+ pass
414
+
415
+ gemma_chatml_ollama = \
416
+ '''
417
+ FROM {__FILE_LOCATION__}
418
+ TEMPLATE """{{ if .System }}<|im_start|>system
419
+ {{ .System }}<|im_end|>
420
+ {{ end }}{{ if .Prompt }}<|im_start|>user
421
+ {{ .Prompt }}<|im_end|>
422
+ {{ end }}<|im_start|>assistant
423
+ {{ .Response }}<|im_end|>
424
+ """
425
+ PARAMETER repeat_penalty 1
426
+ PARAMETER stop "<|im_start|>"
427
+ PARAMETER stop "<|im_end|>"
428
+ PARAMETER penalize_newline false
429
+ PARAMETER temperature 1.5
430
+ PARAMETER min_p 0.1
431
+ '''
432
+
433
+ gemma_chatml_eos_token = (
434
+ {"<start_of_turn>" : "<|im_start|>", "<eos>" : "<|im_end|>"},
435
+ "<|im_end|>",
436
+ )
437
+ CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
438
+ pass
439
+
440
+ # =========================================== Gemma 2
441
+ # Same as Gemma 1, but with sliding window attention!
442
+ # https://ollama.com/library/gemma2/blobs/6522ca797f47
443
+ gemma2_template = gemma_template
444
+ gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
445
+ gemma2_eos_token = "<end_of_turn>"
446
+ CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
447
+
448
+ # =========================================== Gemma 2 with ChatML instead
449
+ gemma2_chatml_template = gemma_chatml_template
450
+ gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
451
+ gemma2_chatml_eos_token = gemma_chatml_eos_token
452
+ CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
453
+ pass
454
+
455
+ # =========================================== Llama-3
456
+ # Weirdly \n\n is needed?
457
+ llama3_template = \
458
+ "{{ bos_token }}"\
459
+ "{% for message in messages %}"\
460
+ "{% if message['role'] == 'user' %}"\
461
+ "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
462
+ "{% elif message['role'] == 'assistant' %}"\
463
+ "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
464
+ "{% else %}"\
465
+ "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
466
+ "{% endif %}"\
467
+ "{% endfor %}"\
468
+ "{% if add_generation_prompt %}"\
469
+ "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
470
+ "{% endif %}"
471
+ pass
472
+
473
+ # Ollama from https://www.ollama.com/library/llama3
474
+ llama3_ollama = \
475
+ '''
476
+ FROM {__FILE_LOCATION__}
477
+ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
478
+
479
+ {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
480
+
481
+ {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
482
+
483
+ {{ .Response }}<|eot_id|>"""
484
+ PARAMETER stop "<|start_header_id|>"
485
+ PARAMETER stop "<|end_header_id|>"
486
+ PARAMETER stop "<|eot_id|>"
487
+ PARAMETER temperature 1.5
488
+ PARAMETER min_p 0.1
489
+ '''
490
+
491
+ llama3_template_eos_token = "eos_token"
492
+ CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
493
+ pass
494
+
495
+
496
+ # =========================================== Phi-3
497
+ # "{{ bos_token }}"\ # Phi-3.5 removes BOS?
498
+ phi3_template = \
499
+ "{% for message in messages %}"\
500
+ "{% if message['role'] == 'user' %}"\
501
+ "{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"\
502
+ "{% elif message['role'] == 'assistant' %}"\
503
+ "{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"\
504
+ "{% else %}"\
505
+ "{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
506
+ "{% endif %}"\
507
+ "{% endfor %}"\
508
+ "{% if add_generation_prompt %}"\
509
+ "{{ '<|assistant|>\n' }}"\
510
+ "{% endif %}"
511
+ pass
512
+
513
+ # Ollama from https://www.ollama.com/library/phi3
514
+ phi3_ollama = \
515
+ '''
516
+ FROM {__FILE_LOCATION__}
517
+ TEMPLATE """{{ if .System }}<|system|>
518
+ {{ .System }}<|end|>
519
+ {{ end }}{{ if .Prompt }}<|user|>
520
+ {{ .Prompt }}<|end|>
521
+ {{ end }}<|assistant|>
522
+ {{ .Response }}<|end|>
523
+ """
524
+ PARAMETER stop "<|end|>"
525
+ PARAMETER stop "<|user|>"
526
+ PARAMETER stop "<|assistant|>"
527
+ PARAMETER temperature 1.5
528
+ PARAMETER min_p 0.1
529
+ '''
530
+
531
+ phi3_template_eos_token = "<|end|>"
532
+ CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
533
+ CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
534
+ CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
535
+ pass
536
+
537
+ # =========================================== Llama-3.1
538
+ """
539
+ No trimming in Llama 3.1 Instruct!
540
+ Also an extra newline for Cutting Knowledge Date
541
+ See https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing
542
+
543
+ The call should also pass date_string, for example:
544
+
545
+ import datetime
546
+ tokenizer.apply_chat_template(
547
+ messages,
548
+ add_generation_prompt = True,
549
+ tokenize = False,
550
+ date_string = datetime.date.today().strftime("%d %B %Y"),
551
+ )
552
+ """
553
+
554
+ llama31_template = \
555
+ """{{- bos_token }}
556
+ {%- if custom_tools is defined %}
557
+ {%- set tools = custom_tools %}
558
+ {%- endif %}
559
+ {%- if not tools_in_user_message is defined %}
560
+ {%- set tools_in_user_message = true %}
561
+ {%- endif %}
562
+ {%- if not date_string is defined %}
563
+ {%- set date_string = "26 July 2024" %}
564
+ {%- endif %}
565
+ {%- if not tools is defined %}
566
+ {%- set tools = none %}
567
+ {%- endif %}
568
+
569
+ {#- This block extracts the system message, so we can slot it into the right place. #}
570
+ {%- if messages[0]['role'] == 'system' %}
571
+ {%- set system_message = messages[0]['content'] %}
572
+ {%- set messages = messages[1:] %}
573
+ {%- else %}
574
+ {%- set system_message = "" %}
575
+ {%- endif %}
576
+
577
+ {#- System message + builtin tools #}
578
+ {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
579
+ {%- if builtin_tools is defined or tools is not none %}
580
+ {{- "Environment: ipython\n" }}
581
+ {%- endif %}
582
+ {%- if builtin_tools is defined %}
583
+ {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
584
+ {%- endif %}
585
+ {{- "Cutting Knowledge Date: December 2023\n" }}
586
+ {{- "Today Date: " + date_string + "\n\n" }}
587
+ {%- if tools is not none and not tools_in_user_message %}
588
+ {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
589
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
590
+ {{- "Do not use variables.\n\n" }}
591
+ {%- for t in tools %}
592
+ {{- t | tojson(indent=4) }}
593
+ {{- "\n\n" }}
594
+ {%- endfor %}
595
+ {%- endif %}
596
+ {{- system_message }}
597
+ {{- "<|eot_id|>" }}
598
+
599
+ {#- Custom tools are passed in a user message with some extra guidance #}
600
+ {%- if tools_in_user_message and not tools is none %}
601
+ {#- Extract the first user message so we can plug it in here #}
602
+ {%- if messages | length != 0 %}
603
+ {%- set first_user_message = messages[0]['content'] %}
604
+ {%- set messages = messages[1:] %}
605
+ {%- else %}
606
+ {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
607
+ {%- endif %}
608
+ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
609
+ {{- "Given the following functions, please respond with a JSON for a function call " }}
610
+ {{- "with its proper arguments that best answers the given prompt.\n\n" }}
611
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
612
+ {{- "Do not use variables.\n\n" }}
613
+ {%- for t in tools %}
614
+ {{- t | tojson(indent=4) }}
615
+ {{- "\n\n" }}
616
+ {%- endfor %}
617
+ {{- first_user_message + "<|eot_id|>"}}
618
+ {%- endif %}
619
+
620
+ {%- for message in messages %}
621
+ {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
622
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
623
+ {%- elif 'tool_calls' in message %}
624
+ {%- if not message.tool_calls|length == 1 %}
625
+ {{- raise_exception("This model only supports single tool-calls at once!") }}
626
+ {%- endif %}
627
+ {%- set tool_call = message.tool_calls[0].function %}
628
+ {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
629
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
630
+ {{- "<|python_tag|>" + tool_call.name + ".call(" }}
631
+ {%- for arg_name, arg_val in tool_call.arguments | items %}
632
+ {{- arg_name + '="' + arg_val + '"' }}
633
+ {%- if not loop.last %}
634
+ {{- ", " }}
635
+ {%- endif %}
636
+ {%- endfor %}
637
+ {{- ")" }}
638
+ {%- else %}
639
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
640
+ {{- '{"name": "' + tool_call.name + '", ' }}
641
+ {{- '"parameters": ' }}
642
+ {{- tool_call.arguments | tojson }}
643
+ {{- "}" }}
644
+ {%- endif %}
645
+ {%- if builtin_tools is defined %}
646
+ {#- This means we're in ipython mode #}
647
+ {{- "<|eom_id|>" }}
648
+ {%- else %}
649
+ {{- "<|eot_id|>" }}
650
+ {%- endif %}
651
+ {%- elif message.role == "tool" or message.role == "ipython" %}
652
+ {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
653
+ {%- if message.content is mapping or message.content is iterable %}
654
+ {{- message.content | tojson }}
655
+ {%- else %}
656
+ {{- message.content }}
657
+ {%- endif %}
658
+ {{- "<|eot_id|>" }}
659
+ {%- endif %}
660
+ {%- endfor %}
661
+ {%- if add_generation_prompt %}
662
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
663
+ {%- endif %}
664
+ """
665
+ pass
666
+
667
+ # Ollama from https://ollama.com/library/llama3.1 (needs updating!)
668
+ llama31_ollama = \
669
+ '''
670
+ FROM {__FILE_LOCATION__}
671
+ TEMPLATE """{{ if .Messages }}
672
+ {{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
673
+ {{- if .System }}
674
+
675
+ {{ .System }}
676
+ {{- end }}
677
+ {{- if .Tools }}
678
+
679
+ You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
680
+ {{- end }}
681
+ {{- end }}<|eot_id|>
682
+ {{- range $i, $_ := .Messages }}
683
+ {{- $last := eq (len (slice $.Messages $i)) 1 }}
684
+ {{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
685
+ {{- if and $.Tools $last }}
686
+
687
+ Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
688
+
689
+ Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
690
+
691
+ {{ $.Tools }}
692
+ {{- end }}
693
+
694
+ {{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
695
+
696
+ {{ end }}
697
+ {{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
698
+ {{- if .ToolCalls }}
699
+
700
+ {{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
701
+ {{- else }}
702
+
703
+ {{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
704
+ {{- end }}
705
+ {{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
706
+
707
+ {{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
708
+
709
+ {{ end }}
710
+ {{- end }}
711
+ {{- end }}
712
+ {{- else }}
713
+ {{- if .System }}<|start_header_id|>system<|end_header_id|>
714
+
715
+ {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
716
+
717
+ {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
718
+
719
+ {{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}"""
720
+ PARAMETER stop "<|start_header_id|>"
721
+ PARAMETER stop "<|end_header_id|>"
722
+ PARAMETER stop "<|eot_id|>"
723
+ PARAMETER stop "<|eom_id|>"
724
+ PARAMETER temperature 1.5
725
+ PARAMETER min_p 0.1
726
+ '''
727
+
728
+ llama31_template_eos_token = "eos_token"
729
+ CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
730
+ CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
731
+ pass
732
+
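+ # A minimal usage sketch for the "llama-3.1" entry registered above, assuming a fast
+ # tokenizer is available (the model name here is illustrative):
+ #
+ # from transformers import AutoTokenizer
+ # tokenizer = AutoTokenizer.from_pretrained("unsloth/Meta-Llama-3.1-8B-Instruct")
+ # tokenizer = get_chat_template(tokenizer, chat_template = "llama-3.1")
+ # text = tokenizer.apply_chat_template(
+ #     [{"role": "user", "content": "Hello!"}],
+ #     tokenize = False, add_generation_prompt = True,
+ # )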
733
+
734
+ # =========================================== Qwen 2.5
735
+ qwen25_template = \
736
+ """{%- if tools %}
737
+ {{- \'<|im_start|>system\\n\' }}
738
+ {%- if messages[0][\'role\'] == \'system\' %}
739
+ {{- messages[0][\'content\'] }}
740
+ {%- else %}
741
+ {{- \'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\' }}
742
+ {%- endif %}
743
+ {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
744
+ {%- for tool in tools %}
745
+ {{- "\\n" }}
746
+ {{- tool | tojson }}
747
+ {%- endfor %}
748
+ {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}
749
+ {%- if messages[0][\'role\'] == \'system\' %}
750
+ {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}
751
+ {%- else %}
752
+ {{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }}
753
+ {%- endif %}\n{%- endif %}\n{%- for message in messages %}
754
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
755
+ {{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}
756
+ {%- elif message.role == "assistant" %}
757
+ {{- \'<|im_start|>\' + message.role }}
758
+ {%- if message.content %}
759
+ {{- \'\\n\' + message.content }}
760
+ {%- endif %}
761
+ {%- for tool_call in message.tool_calls %}
762
+ {%- if tool_call.function is defined %}
763
+ {%- set tool_call = tool_call.function %}
764
+ {%- endif %}
765
+ {{- \'\\n<tool_call>\\n{"name": "\' }}
766
+ {{- tool_call.name }}
767
+ {{- \'", "arguments": \' }}
768
+ {{- tool_call.arguments | tojson }}
769
+ {{- \'}\\n</tool_call>\' }}
770
+ {%- endfor %}
771
+ {{- \'<|im_end|>\\n\' }}
772
+ {%- elif message.role == "tool" %}
773
+ {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} {{- \'<|im_start|>user\' }}
774
+ {%- endif %}
775
+ {{- \'\\n<tool_response>\\n\' }}
776
+ {{- message.content }}
777
+ {{- \'\\n</tool_response>\' }}
778
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
779
+ {{- \'<|im_end|>\\n\' }}
780
+ {%- endif %}
781
+ {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}
782
+ {{- \'<|im_start|>assistant\\n\' }}
783
+ {%- endif %}
784
+ """
785
+
786
+
787
+ # Ollama from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
788
+ qwen25_ollama = \
789
+ '''
790
+ FROM {__FILE_LOCATION__}
791
+ TEMPLATE """{{- if .Messages }}
792
+ {{- if or .System .Tools }}<|im_start|>system
793
+ {{- if .System }}
794
+ {{ .System }}
795
+ {{- end }}
796
+ {{- if .Tools }}
797
+
798
+ # Tools
799
+
800
+ You may call one or more functions to assist with the user query.
801
+
802
+ You are provided with function signatures within <tools></tools> XML tags:
803
+ <tools>
804
+ {{- range .Tools }}
805
+ {"type": "function", "function": {{ .Function }}}
806
+ {{- end }}
807
+ </tools>
808
+
809
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
810
+ <tool_call>
811
+ {"name": <function-name>, "arguments": <args-json-object>}
812
+ </tool_call>
813
+ {{- end }}<|im_end|>
814
+ {{ end }}
815
+ {{- range $i, $_ := .Messages }}
816
+ {{- $last := eq (len (slice $.Messages $i)) 1 -}}
817
+ {{- if eq .Role "user" }}<|im_start|>user
818
+ {{ .Content }}<|im_end|>
819
+ {{ else if eq .Role "assistant" }}<|im_start|>assistant
820
+ {{ if .Content }}{{ .Content }}
821
+ {{- else if .ToolCalls }}<tool_call>
822
+ {{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
823
+ {{ end }}</tool_call>
824
+ {{- end }}{{ if not $last }}<|im_end|>
825
+ {{ end }}
826
+ {{- else if eq .Role "tool" }}<|im_start|>user
827
+ <tool_response>
828
+ {{ .Content }}
829
+ </tool_response><|im_end|>
830
+ {{ end }}
831
+ {{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
832
+ {{ end }}
833
+ {{- end }}
834
+ {{- else }}
835
+ {{- if .System }}<|im_start|>system
836
+ {{ .System }}<|im_end|>
837
+ {{ end }}{{ if .Prompt }}<|im_start|>user
838
+ {{ .Prompt }}<|im_end|>
839
+ {{ end }}<|im_start|>assistant
840
+ {{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
841
+ PARAMETER stop "<|im_end|>"
842
+ PARAMETER stop "<|endoftext|>"
843
+ PARAMETER temperature 1.5
844
+ PARAMETER min_p 0.1
845
+ '''
846
+
847
+ qwen25_template_eos_token = "eos_token"
848
+ CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
849
+ CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
850
+ CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
851
+ CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
852
+ pass
853
+
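+ # All four "qwen" keys registered above point at the same
+ # (template, eos_token, map_eos, ollama_modelfile) tuple; a quick sketch:
+ #
+ # template, eos_name, _, modelfile = CHAT_TEMPLATES["qwen-2.5"]
+ # assert CHAT_TEMPLATES["qwen25"][0] is template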
854
+
855
+ def get_chat_template(
856
+ tokenizer,
857
+ chat_template = "chatml",
858
+ mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
859
+ map_eos_token = True,
860
+ system_message = None,
861
+ ):
862
+ assert(type(map_eos_token) is bool)
863
+ old_tokenizer = tokenizer
864
+
865
+ IS_GEMMA = False
866
+ if tokenizer.__class__.__name__.startswith("Gemma"):
867
+ if chat_template == "chatml": chat_template = "gemma_chatml"
868
+ IS_GEMMA = True
869
+ pass
870
+
871
+ # We add a check for Llama-3
872
+ # if chat_template == "llama-3":
873
+ # tokenizer._using_llama3_template = True
874
+ # else:
875
+ # llama3_tokens = set(["<|end_header_id|>", "<|eot_id|>", "<|start_header_id|>"])
876
+ # check_llama3_tokens = llama3_tokens & set(str(x) for x in tokenizer.added_tokens_decoder.values())
877
+ # if len(check_llama3_tokens) == len(llama3_tokens):
878
+ # tokenizer._using_llama3_template = True
879
+ # pass
880
+ # pass
881
+
882
+ # We first check if the tokenizer is a fast one. If not, we cannot convert this!
883
+ is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
884
+ old_padding_side = tokenizer.padding_side
885
+
886
+ same_padding_token = False
887
+
888
+ if type(chat_template) in (list, tuple,):
889
+ chat_template, stop_word = chat_template
890
+ assert(type(chat_template) is str)
891
+ assert(type(stop_word) is str)
892
+ ollama_modelfile = None
893
+
894
+ elif type(chat_template) is str:
895
+
896
+ chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
897
+
898
+ # Check mapping to eos_token
899
+ if not map_eos_token and yes_map_eos_token: map_eos_token = True
900
+ if not yes_map_eos_token and map_eos_token: map_eos_token = False
901
+
902
+ if type(stop_word) in (list, tuple,):
903
+ token_mapping, stop_word = stop_word
904
+ assert(type(token_mapping) is dict)
905
+ else:
906
+ token_mapping = None
907
+
908
+ assert(type(stop_word) is str)
909
+
910
+ # Check fast tokenizer
911
+ if not is_fast_tokenizer:
912
+ print(
913
+ f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
914
+ "Please log a Github issue if you want this as a new feature!\n"\
915
+ "Your chat template will still work, but it won't add or edit tokens."
916
+ )
917
+
918
+ elif token_mapping is not None:
919
+ # token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
920
+ # For Gemma :)
921
+
922
+ string_vocab = tokenizer._tokenizer.to_str()
923
+
924
+ skipped = 0
925
+ for old_token, new_token in token_mapping.items():
926
+ old_count = string_vocab.count(f'"{old_token}"')
927
+ new_count = string_vocab.count(f'"{new_token}"')
928
+ if new_count != 0:
929
+ print(f"{new_token} is already a token. Skipping.")
930
+ skipped += 1
931
+ elif old_count == 0:
932
+ raise RuntimeError(f"{old_token} was not part of the tokenizer!")
933
+ else:
934
+ string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
935
+ pass
936
+ pass
937
+
938
+ if map_eos_token and (not stop_word in token_mapping.values()):
939
+ # Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
940
+ logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
941
+ string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
942
+ pass
943
+
944
+ if skipped != len(token_mapping):
945
+ new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
946
+
947
+ # Careful on pad_token
948
+ old_pad_token = tokenizer.pad_token
949
+ if old_pad_token == tokenizer.eos_token:
950
+ old_pad_token = stop_word
951
+ same_padding_token = True
952
+ pass
953
+
954
+ if map_eos_token:
955
+ new_tokenizer = tokenizer.__class__(
956
+ tokenizer_object = new_tokenizer,
957
+ eos_token = stop_word,
958
+ pad_token = old_pad_token,
959
+ )
960
+ else:
961
+ new_tokenizer = tokenizer.__class__(
962
+ tokenizer_object = new_tokenizer,
963
+ pad_token = old_pad_token,
964
+ )
965
+ pass
966
+
967
+ # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
968
+ tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
969
+ else:
970
+ pass
971
+
972
+ elif map_eos_token and (stop_word != "eos_token"):
973
+ logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
974
+
975
+ # Replaces the old EOS token with a new one.
976
+ # Useful for ChatML <|im_end|> for example.
977
+ # Usually we train 2 more tokens <|im_start|> and <|im_end|>
978
+ # But training the lm_head and embeddings are slow!
979
+ # This is a HACK!
980
+ # Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
981
+
982
+ old_bos_token = getattr(tokenizer, "bos_token", None)
983
+ old_eos_token = getattr(tokenizer, "eos_token", None)
984
+ old_pad_token = getattr(tokenizer, "pad_token", None)
985
+ old_unk_token = getattr(tokenizer, "unk_token", None)
986
+
987
+ string_vocab = tokenizer._tokenizer.to_str()
988
+ # First check if new stop_word is in the tokenizer
989
+ if stop_word in string_vocab:
990
+ # We shall swap them around
991
+ temporary_stop_token = "<|:__TEMP//STOP//TOKEN__:|>"
992
+ string_vocab = string_vocab.replace(old_eos_token, temporary_stop_token)
993
+ string_vocab = string_vocab.replace(stop_word, old_eos_token)
994
+ string_vocab = string_vocab.replace(temporary_stop_token, stop_word)
995
+ else:
996
+ string_vocab = string_vocab.replace(old_eos_token, stop_word)
997
+ pass
998
+ new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
999
+
1000
+ # Careful on pad_token
1001
+ if old_pad_token == old_eos_token:
1002
+ old_pad_token = stop_word
1003
+ same_padding_token = True
1004
+ pass
1005
+
1006
+ new_tokenizer = tokenizer.__class__(
1007
+ tokenizer_object = new_tokenizer,
1008
+ bos_token = old_bos_token,
1009
+ eos_token = stop_word,
1010
+ unk_token = old_unk_token,
1011
+ pad_token = old_pad_token,
1012
+ )
1013
+
1014
+ # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
1015
+ token_mapping = { old_eos_token : stop_word, }
1016
+ tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
1017
+ pass
1018
+
1019
+ else:
1020
+ raise TypeError(
1021
+ f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
1022
+ f"{CHAT_TEMPLATES.keys()}"
1023
+ )
1024
+ pass
1025
+
1026
+ # Careful on Gemma
1027
+ # bos_token is a must or else losses become too high
1028
+ if IS_GEMMA and not chat_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
1029
+ chat_template = "{{ bos_token }}" + chat_template
1030
+ pass
1031
+
1032
+ # For ShareGPT role -> from and content -> value
1033
+ new_chat_template = chat_template\
1034
+ .replace("'role'", "'" + mapping["role"] + "'")\
1035
+ .replace("'content'", "'" + mapping["content"] + "'")\
1036
+ .replace("'user'", "'" + mapping["user"] + "'")\
1037
+ .replace("'assistant'", "'" + mapping["assistant"] + "'")
1038
+
1039
+ _, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
1040
+ tokenizer.padding_side = old_padding_side
1041
+
1042
+ # If not normal HF, we add a check to make old templates work
1043
+ if mapping != {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"}:
1044
+ chat_template = \
1045
+ "{% if 'role' in messages[0] %}" + \
1046
+ chat_template + \
1047
+ "{% else %}" + \
1048
+ new_chat_template + \
1049
+ "{% endif %}"
1050
+ else:
1051
+ chat_template = new_chat_template
1052
+ pass
1053
+ tokenizer.chat_template = chat_template
1054
+
1055
+ # Also fix up other tokens
1056
+ old_pad_token = getattr(old_tokenizer, "pad_token", None)
1057
+ old_bos_token = getattr(old_tokenizer, "bos_token", None)
1058
+ old_unk_token = getattr(old_tokenizer, "unk_token", None)
1059
+ new_pad_token = getattr(tokenizer, "pad_token", None)
1060
+ new_bos_token = getattr(tokenizer, "bos_token", None)
1061
+ new_unk_token = getattr(tokenizer, "unk_token", None)
1062
+ if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
1063
+ if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
1064
+ if not same_padding_token:
1065
+ if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
1066
+ pass
1067
+
1068
+ # stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
1069
+
1070
+ # Patch saving functions
1071
+ tokenizer = patch_saving_functions(tokenizer)
1072
+
1073
+ # Add Ollama
1074
+ tokenizer._ollama_modelfile = ollama_modelfile
1075
+ tokenizer._system_message = system_message
1076
+ return tokenizer#, stopping_criteria
1077
+ pass
1078
+
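+ # A minimal usage sketch, assuming a fast tokenizer and a ShareGPT-style dataset whose
+ # messages use "from"/"value" keys (the mapping below mirrors that layout):
+ #
+ # tokenizer = get_chat_template(
+ #     tokenizer,
+ #     chat_template = "chatml",
+ #     mapping = {"role": "from", "content": "value", "user": "human", "assistant": "gpt"},
+ #     map_eos_token = True,
+ # )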
1079
+
1080
+ def remove_special_tokens(tokenizer, prompt):
1081
+ # Removes double BOS token
1082
+ if prompt.startswith(tokenizer.bos_token):
1083
+ prompt = prompt[len(tokenizer.bos_token):]
1084
+ pass
1085
+ return prompt
1086
+ pass
1087
+
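+ # For example (assuming the tokenizer's BOS token is "<s>"):
+ # remove_special_tokens(tokenizer, "<s>Hello") -> "Hello"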
1088
+
1089
+ def _parse_combined_prompt(combined_prompt, dataset):
1090
+ # Find {...}
1091
+ possible_columns = re.findall(r"\{(.+?)\}", combined_prompt)
1092
+ dataset_columns = set(dataset.column_names)
1093
+ for column in possible_columns:
1094
+ if column not in dataset_columns:
1095
+ raise KeyError(
1096
+ f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
1097
+ f"Only allowed columns are {list(dataset_columns)}"
1098
+ )
1099
+ pass
1100
+ pass
1101
+
1102
+ # Find [[...]]
1103
+ optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
1104
+ optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]
1105
+
1106
+ final_optional_prompts = []
1107
+ if len(optional_prompts) != 0:
1108
+ # Add left
1109
+ left = optional_prompts[0]
1110
+ l = left[0][0]
1111
+ if l != 0: final_optional_prompts.append(combined_prompt[:l])
1112
+
1113
+ # Add in between
1114
+ for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):
1115
+ l, r = left[0][-1], right[0][0]
1116
+ final_optional_prompts.append(left)
1117
+ if l != r: final_optional_prompts.append(combined_prompt[l : r])
1118
+ pass
1119
+ final_optional_prompts.append(optional_prompts[-1])
1120
+
1121
+ # Add right
1122
+ right = optional_prompts[-1]
1123
+ r = right[0][1]
1124
+ if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])
1125
+ else:
1126
+ # Just add in the entire string
1127
+ final_optional_prompts.append(combined_prompt)
1128
+ pass
1129
+
1130
+ check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
1131
+ assert(combined_prompt == check_combined)
1132
+
1133
+ return possible_columns, final_optional_prompts
1134
+ pass
1135
+
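+ # A sketch of the merged prompt syntax this parser handles (column names illustrative):
+ #
+ # merged_prompt = "{instruction}[[\nYour input is:\n{input}]]"
+ #
+ # {...} must name existing dataset columns; a [[...]] block is optional and is dropped
+ # when its first column is empty (see `_create_formatter` below).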
1136
+
1137
+ def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
1138
+ # Start final prompt!
1139
+ function = ["def __combined_prompt_processor__(examples):"]
1140
+ columns = list(set(possible_columns))
1141
+ for column in columns:
1142
+ function.append(f"{' '*4}{column}__ = examples['{column}']")
1143
+ function.append(f"{' '*4}texts = []")
1144
+ function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
1145
+
1146
+ # Add optional tags as well!
1147
+ final_prompt = ""
1148
+ formatter = []
1149
+
1150
+ for j, optional_prompt in enumerate(final_optional_prompts):
1151
+ if type(optional_prompt) is str:
1152
+ columns = re.findall(r"\{(.+?)\}", optional_prompt)
1153
+ formatter += columns
1154
+ # Must escape \n \r
1155
+ final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
1156
+ else:
1157
+ where, prompt = optional_prompt
1158
+ # Strip [[...]]
1159
+ # Must escape \n \r
1160
+ prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
1161
+ columns = re.findall(r"\{(.+?)\}", prompt)
1162
+ x = f"__optional_{j}__"
1163
+ prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
1164
+ function.append(prompt)
1165
+ formatter.append(x)
1166
+ final_prompt += "{" + x + "}"
1167
+ pass
1168
+ pass
1169
+
1170
+ function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
1171
+ function.append(f"{' '*8}texts.append("\
1172
+ f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
1173
+ function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
1174
+ return "\n".join(function)
1175
+ pass
1176
+
1177
+
1178
+ def to_sharegpt(
1179
+ dataset,
1180
+ merged_prompt = "",
1181
+ merged_column_name = "instruction",
1182
+ output_column_name = "output",
1183
+ remove_unused_columns = True,
1184
+ conversation_extension = 1,
1185
+ random_state = 3407,
1186
+ ):
1187
+ """
1188
+ Converts a dataset to ShareGPT style.
1189
+ ShareGPT requires only 1 input and 1 output field.
1190
+ This means one has to merge multiple columns into 1 for 1 input field.
1191
+ Use `conversation_extension` to increase the length of each conversation by randomly
1192
+ selecting a few and packing them into 1.
1193
+
1194
+ merged_prompt = "", Prompt to merge columns into 1 input
1195
+ merged_column_name = "instruction", Final column name for the input field
1196
+ output_column_name = "output", Final column name for the output field
1197
+ remove_unused_columns = True,
1198
+ conversation_extension = 1, Automatically combines `conversation_extension` convos into 1
1199
+ random_state = 3407,
1200
+ """
1201
+ if "conversations" in dataset.column_names:
1202
+ convo = dataset[0]["conversations"]
1203
+ if type(convo) is list:
1204
+ raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
1205
+ pass
1206
+ pass
1207
+
1208
+ possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
1209
+ function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
1210
+ exec(function, globals())
1211
+ dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
1212
+
1213
+ def __convert_to_sharegpt__(examples):
1214
+ users = examples[merged_column_name]
1215
+ assistants = examples[output_column_name]
1216
+ texts = [
1217
+ [
1218
+ {"from" : "human", "value" : str(user) },
1219
+ {"from" : "gpt", "value" : str(assistant)},
1220
+ ] \
1221
+ for user, assistant in zip(users, assistants)
1222
+ ]
1223
+ return { "conversations" : texts, }
1224
+ pass
1225
+
1226
+ dataset = dataset.map(
1227
+ __convert_to_sharegpt__,
1228
+ batched = True,
1229
+ desc = "Converting to ShareGPT",
1230
+ # Remove unused columns!
1231
+ remove_columns = dataset.column_names if remove_unused_columns else None,
1232
+ )
1233
+
1234
+ # Randomly concatenate conversations to create a long stream!
1235
+ from datasets import concatenate_datasets
1236
+ n_extensions = max(conversation_extension-1, 0)
1237
+ if n_extensions == 0: return dataset
1238
+
1239
+ dataset = dataset.rename_columns({"conversations" : f"conversations0"})
1240
+ all_shuffled = [dataset]
1241
+ for j in range(1, n_extensions+1):
1242
+ shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
1243
+ all_shuffled.append(shuffled)
1244
+ pass
1245
+ dataset = concatenate_datasets(all_shuffled, axis = 1)
1246
+
1247
+ # Combine them into 1
1248
+ function = "def __combine_conversations__(examples):\n"
1249
+ n_extensions += 1
1250
+ for j in range(n_extensions):
1251
+ function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
1252
+ function += f"{' '*4}convos = []\n"
1253
+ function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
1254
+ f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
1255
+ function += f"{' '*8}convos.append("\
1256
+ f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
1257
+ function += f"{' '*4}return " + "{ " + f"'conversations' : convos" + " }"
1258
+
1259
+ # Map function
1260
+ exec(function, globals())
1261
+ dataset = dataset.map(
1262
+ __combine_conversations__,
1263
+ batched = True,
1264
+ desc = "Extending conversations",
1265
+ # Remove unused columns!
1266
+ remove_columns = dataset.column_names if remove_unused_columns else None,
1267
+ )
1268
+ return dataset
1269
+ pass
1270
+
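+ # A usage sketch, assuming an Alpaca-style dataset with "instruction", "input" and
+ # "output" columns:
+ #
+ # dataset = to_sharegpt(
+ #     dataset,
+ #     merged_prompt = "{instruction}[[\nYour input is:\n{input}]]",
+ #     output_column_name = "output",
+ #     conversation_extension = 3,  # randomly pack 3 rows into one conversation
+ # )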
1271
+
1272
+ def standardize_sharegpt(
1273
+ dataset,
1274
+ aliases_for_system = ["system",],
1275
+ aliases_for_user = ["user", "human", "input",],
1276
+ aliases_for_assistant = ["gpt", "assistant", "output",],
1277
+ ):
1278
+ """
1279
+ Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
1280
+
1281
+ Get aliases for the system, user and assistant roles.
1282
+ These shall map to "system", "user" and "assistant" respectively.
1283
+
1284
+ aliases_for_system = ["system",],
1285
+ aliases_for_user = ["user", "human", "input",],
1286
+ aliases_for_assistant = ["gpt", "assistant", "output",],
1287
+ """
1288
+ import collections
1289
+ import itertools
1290
+
1291
+ convos = dataset[:10]["conversations"]
1292
+ uniques = collections.defaultdict(list)
1293
+ for convo in convos:
1294
+ for message in convo:
1295
+ for key, value in message.items():
1296
+ uniques[key].append(value)
1297
+ pass
1298
+
1299
+ # Must be only 2 entries
1300
+ assert(len(uniques.keys()) == 2)
1301
+
1302
+ keys = list(uniques.keys())
1303
+ length_first = len(set(uniques[keys[0]]))
1304
+ length_second = len(set(uniques[keys[1]]))
1305
+
1306
+ if length_first < length_second:
1307
+ # Role is assigned to the first element
1308
+ role_key = keys[0]
1309
+ content_key = keys[1]
1310
+ else:
1311
+ role_key = keys[1]
1312
+ content_key = keys[0]
1313
+ pass
1314
+
1315
+ # Check roles are in aliases
1316
+ all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
1317
+ roles = set(uniques[role_key])
1318
+ leftover_aliases = (all_aliases | roles) - all_aliases
1319
+ if len(leftover_aliases) != 0:
1320
+ raise TypeError(
1321
+ f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
1322
+ )
1323
+ pass
1324
+
1325
+ # Mapping for aliases
1326
+ aliases_mapping = {}
1327
+ for x in aliases_for_system: aliases_mapping[x] = "system"
1328
+ for x in aliases_for_user: aliases_mapping[x] = "user"
1329
+ for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
1330
+
1331
+ def _standardize_dataset(examples):
1332
+ convos = examples["conversations"]
1333
+ all_convos = []
1334
+ for convo in convos:
1335
+ new_convo = [
1336
+ { "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
1337
+ for message in convo
1338
+ ]
1339
+ all_convos.append(new_convo)
1340
+ pass
1341
+ return { "conversations" : all_convos, }
1342
+ pass
1343
+
1344
+ return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
1345
+ pass
1346
+
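+ # A usage sketch: {"from": "human", "value": "Hi"} becomes {"role": "user", "content": "Hi"}
+ # and {"from": "gpt", ...} becomes {"role": "assistant", ...}:
+ #
+ # dataset = standardize_sharegpt(dataset)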
1347
+
1348
+ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
1349
+ added_tokens_decoder = tokenizer.added_tokens_decoder.values()
1350
+ added_tokens_decoder = [str(x) for x in added_tokens_decoder]
1351
+
1352
+ # Remove added_tokens_decoder duplicates
1353
+ added_tokens_decoder = list(set(added_tokens_decoder) - set(extra_eos_tokens))
1354
+
1355
+ # Remove BOS
1356
+ if getattr(tokenizer, "bos_token", None) is not None:
1357
+ added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
1358
+ pass
1359
+
1360
+ repeatted_tokens = []
1361
+ # Join all vocab
1362
+ joined_text = "\x01\x00".join(added_tokens_decoder)
1363
+ for token in added_tokens_decoder:
1364
+ n = len(token)
1365
+ repeatted_counts = joined_text.count(token[:n//2])
1366
+ # Try finding longer than 1/2 of the token in the rest
1367
+ # For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>
1368
+ if repeatted_counts > 2:
1369
+ for j in range(n//2+1, n):
1370
+ if joined_text.count(token[:j]) < repeatted_counts:
1371
+ j -= 1
1372
+ # Remove repeated tokens to reduce search space
1373
+ joined_text = joined_text.replace(token[:j], "")
1374
+ repeatted_tokens.append(token[:j])
1375
+ break
1376
+ pass
1377
+ pass
1378
+ pass
1379
+
1380
+ # Remove duplicates
1381
+ splitted = joined_text.split("\x01\x00")
1382
+ final_eos_tokens = []
1383
+ for old, new in zip(added_tokens_decoder, splitted):
1384
+ if old == new: final_eos_tokens.append(old)
1385
+ pass
1386
+ final_eos_tokens += extra_eos_tokens
1387
+ final_eos_tokens += repeatted_tokens
1388
+
1389
+ # Remove new lines, spaces and HTML tags
1390
+ filtered_eos_tokens = []
1391
+ for token in final_eos_tokens:
1392
+ if token.count("\n") == len(token): continue
1393
+ elif token.count("▁") == len(token): continue
1394
+ elif token.startswith("<") and len(token) <= 2: continue
1395
+ elif token.startswith("</") and len(token) == 3: continue
1396
+ filtered_eos_tokens.append(token)
1397
+ pass
1398
+ return filtered_eos_tokens
1399
+ pass
1400
+
1401
+
1402
+ def construct_chat_template( \
1403
+
1404
+ tokenizer = None,
1405
+
1406
+ chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1407
+
1408
+ {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
1409
+
1410
+ {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1411
+
1412
+ {OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
1413
+
1414
+ {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1415
+
1416
+ {OUTPUT}<|eot_id|>""",
1417
+
1418
+ default_system_message = \
1419
+ "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
1420
+
1421
+ extra_eos_tokens = None,
1422
+ ):
1423
+ """
1424
+ Creates an Ollama modelfile and an HF Jinja template from a custom
1425
+ template. You must provide 2x examples of an input & output.
1426
+ There is an optional system message as well.
1427
+
1428
+ You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
1429
+ """
1430
+ # Strip only the left
1431
+ chat_template = chat_template.lstrip()
1432
+
1433
+ assert(tokenizer is not None)
1434
+
1435
+ if extra_eos_tokens is None: extra_eos_tokens = []
1436
+ elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]
1437
+
1438
+ vocab = tokenizer.get_vocab()
1439
+ for extra_eos in extra_eos_tokens:
1440
+ assert(type(extra_eos) is str)
1441
+ if extra_eos not in vocab:
1442
+ raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
1443
+ pass
1444
+ pass
1445
+
1446
+ error_msg = \
1447
+ "Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
1448
+ "and the assistant output {OUTPUT}\n\n"\
1449
+ "For example what is not allowed is just:\n"\
1450
+ "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
1451
+ "What is required is 2x of this:\n"\
1452
+ "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
1453
+ "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
1454
+
1455
+ # Check for EOS after {OUTPUT}
1456
+ if tokenizer.eos_token is not None:
1457
+ extra_eos_tokens.insert(0, tokenizer.eos_token)
1458
+ if len(extra_eos_tokens) == 0:
1459
+ raise RuntimeError(
1460
+ "Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
1461
+ )
1462
+ pass
1463
+
1464
+ # Check tokenizer types
1465
+ tokenizer_name = tokenizer.name_or_path.lower()
1466
+ if tokenizer_name.startswith(("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")):
1467
+ # Add <|eot_id|>
1468
+ extra_eos_tokens.append("<|eot_id|>")
1469
+ elif ("<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template) and \
1470
+ tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
1471
+ # Warn
1472
+ logger.warning(
1473
+ "Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
1474
+ "Please use the instruct version or use <|end_of_text|>"
1475
+ )
1476
+ pass
1477
+ extra_eos_tokens = list(set(extra_eos_tokens))
1478
+
1479
+ count_eos = 0
1480
+ for eos in extra_eos_tokens:
1481
+ count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
1482
+ pass
1483
+
1484
+ # This forces you to provide 2 input and outputs
1485
+ final_combined_check = False
1486
+
1487
+ try:
1488
+ # O(N^2) search finding 2 repeated pieces of text
1489
+ j = len(chat_template)-1
1490
+ at_least_one = False
1491
+ while j > 0:
1492
+ found = chat_template.rfind(chat_template[j:], 0, j)
1493
+ if found == -1: break
1494
+ j -= 1
1495
+ at_least_one = True
1496
+ pass
1497
+ if j > 0: j += 1
1498
+ else: raise RuntimeError(error_msg)
1499
+
1500
+ if not at_least_one: raise RuntimeError(error_msg)
1501
+
1502
+ # Must be equivalent to left
1503
+ final_combined_check = True
1504
+
1505
+ # Repeated text
1506
+ instruction_response = chat_template[j:]
1507
+ if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
1508
+ raise RuntimeError(error_msg)
1509
+ pass
1510
+
1511
+ # 1st System, Instruction, Output pair
1512
+ left = chat_template[:j]
1513
+ # 2nd Instruction, Output pair
1514
+ right = chat_template[j:]
1515
+
1516
+ final_combined_check = left if final_combined_check else chat_template
1517
+
1518
+ # Isolate input
1519
+ extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
1520
+ if len(extra_eos_tokens_regex) != 0:
1521
+ find_end = f"(?:{extra_eos_tokens_regex})?"
1522
+ else:
1523
+ find_end = ""
1524
+ find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
1525
+ input_end = list(re.finditer(find_end, right))
1526
+ assert(len(input_end) == 1)
1527
+ input_end = input_end[0]
1528
+ input_end = input_end.span(0)[1]
1529
+ input_part = right[:input_end]
1530
+
1531
+ # Isolate output
1532
+ output_part = right[input_end:]
1533
+
1534
+ # Isolate system
1535
+ where_system = left.find(input_part)
1536
+ system_part = left[:where_system if where_system != -1 else len(left)]
1537
+
1538
+ # Check if the user provided a correct prompt
1539
+ combined = system_part + input_part + output_part
1540
+ if combined != final_combined_check:
1541
+ combined_changed = combined .replace('\n', '\\n')
1542
+ left_changed = final_combined_check.replace('\n', '\\n')
1543
+ raise RuntimeError(
1544
+ "Unsloth: The prompt template you provided isn't correct. You gave:\n"\
1545
+ f"{combined_changed}\n\n"\
1546
+ "But we require the following:\n"\
1547
+ f"{left_changed}"
1548
+ )
1549
+ pass
1550
+ except:
1551
+ ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
1552
+
1553
+ ending = re.escape(ending)
1554
+ find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
1555
+ response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
1556
+ response_part = response_part[0]
1557
+
1558
+ for j in range(1, len(response_part)):
1559
+ try_find = re.escape(response_part[:j])
1560
+ try: found = next(re.finditer("(" + try_find + r").+?\{INPUT\}", chat_template, flags = re.DOTALL | re.MULTILINE))
1561
+ except: break
1562
+ pass
1563
+ separator = found.group(1)
1564
+
1565
+ response_start = chat_template.find(response_part)
1566
+ start_instruction = chat_template[:response_start].rfind(separator)
1567
+ if start_instruction == -1: start_instruction = 0
1568
+ instruction_part = chat_template[start_instruction:response_start]
1569
+
1570
+ combined = instruction_part + response_part
1571
+ where = chat_template.find(combined)
1572
+ system_part = chat_template[:where]
1573
+
1574
+ system_part, input_part, output_part = system_part, instruction_part, response_part
1575
+ pass
1576
+
1577
+ if count_eos == 0:
1578
+ logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
1579
+ eos = extra_eos_tokens[0]
1580
+ output_part = output_part + eos
1581
+ pass
1582
+
1583
+ # Ollama modelfile parts
1584
+
1585
+ # Check bos_token is in system prompt
1586
+ ollama_system = system_part
1587
+ has_bos_token = False
1588
+ always_bos_token = False
1589
+ if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
1590
+ always_bos_token = True
1591
+ if ollama_system.startswith(tokenizer.bos_token):
1592
+ has_bos_token = True
1593
+ ollama_system = ollama_system[len(tokenizer.bos_token):]
1594
+ pass
1595
+ pass
1596
+ # Check system
1597
+ if "{SYSTEM}" in ollama_system:
1598
+ system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
1599
+ else:
1600
+ system_modelfile = ollama_system
1601
+ pass
1602
+ input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}"
1603
+ output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
1604
+
1605
+ # Ollama EOS
1606
+ ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)
1607
+ ollama_eos = '\n'.join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
1608
+
1609
+ # Add temperature and min_p to counteract gibberish
1610
+ ollama_eos += "\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1"
1611
+
1612
+ # Ollama modelfile
1613
+ part = '"""'
1614
+ modelfile = 'FROM {__FILE_LOCATION__}\n\n'\
1615
+ 'TEMPLATE ' + part + system_modelfile + input_modelfile + output_modelfile + \
1616
+ part + '\n\n' + ollama_eos
1617
+
1618
+ # HF Jinja Chat template
1619
+ def process(part, which, content = "message['content']"):
1620
+ if part.endswith(which):
1621
+ part = "'" + part[:part.find(which)] + f"' + {content}"
1622
+ elif part.startswith(which):
1623
+ part = f"{content} + '" + part[part.find(which):] + "'"
1624
+ else:
1625
+ part = "'" + part.replace(which, f"' + {content} + '") + "'"
1626
+ if part.startswith("'' + "): part = part[5:]
1627
+ return part
1628
+ pass
1629
+ input_jinja = process(input_part, "{INPUT}")
1630
+ output_jinja = process(output_part, "{OUTPUT}")
1631
+ pass
1632
+
1633
+ jinja_template = \
1634
+ "{% for message in loop_messages %}"\
1635
+ "{% if message['role'] == 'user' %}"\
1636
+ "{{ " + input_jinja + " }}"\
1637
+ "{% elif message['role'] == 'assistant' %}"\
1638
+ "{{ " + output_jinja + " }}"\
1639
+ "{% else %}"\
1640
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
1641
+ "{% endif %}"\
1642
+ "{% endfor %}"\
1643
+ "{% if add_generation_prompt %}"\
1644
+ "{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
1645
+ "{% endif %}"
1646
+ pass
1647
+
1648
+ # Now add system prompt to jinja
1649
+ if len(system_part) != 0:
1650
+ partial_system = process(system_part, "{SYSTEM}", "messages[0]['content']")
1651
+ partial_system = partial_system.replace("{SYSTEM}", "")
1652
+
1653
+ if "{SYSTEM}" in partial_system:
1654
+ if default_system_message is None:
1655
+ raise RuntimeError("Unsloth: Please specify a default system message!")
1656
+ pass
1657
+
1658
+ # Separate the BOS
1659
+ if has_bos_token:
1660
+ partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
1661
+ system_part = system_part .replace(tokenizer.bos_token, "", 1)
1662
+ pass
1663
+
1664
+ partial_system = \
1665
+ "{% if messages[0]['role'] == 'system' %}"\
1666
+ "{{ " + partial_system + " }}"\
1667
+ "{% set loop_messages = messages[1:] %}"
1668
+ if default_system_message is not None:
1669
+ full_system = system_part.replace("{SYSTEM}", default_system_message)
1670
+ if "{SYSTEM}" in system_part:
1671
+ modelfile += '\nSYSTEM "' + default_system_message + '"'
1672
+ pass
1673
+ partial_system += "{% else %}"\
1674
+ "{{ '" + full_system + "' }}"\
1675
+ "{% set loop_messages = messages %}"\
1676
+ "{% endif %}"
1677
+ else:
1678
+ partial_system += "{% endif %}"
1679
+ pass
1680
+
1681
+ jinja_template = partial_system + jinja_template
1682
+
1683
+ if has_bos_token:
1684
+ jinja_template = "{{ bos_token }}" + jinja_template
1685
+ pass
1686
+
1687
+ # Fix missing loop_messages
1688
+ if "{% set loop_messages = messages %}" not in jinja_template:
1689
+ jinja_template = jinja_template.replace(
1690
+ "{% for message in loop_messages %}",
1691
+ "{% for message in messages %}",
1692
+ 1, # Only replace the first one
1693
+ )
1694
+ pass
1695
+
1696
+ # Check if system part is the same!
1697
+ jinja_template = re.sub(
1698
+ r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"\
1699
+ r"\{\% set loop\_messages \= messages\[1\:\] \%\}"\
1700
+ r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"\
1701
+ r"\{\% for message in loop\_messages \%\}",
1702
+ r"{{ '\1' }}{% for message in messages %}",
1703
+ jinja_template, flags = re.MULTILINE | re.DOTALL,
1704
+ )
1705
+
1706
+ # Check jinja template for bos
1707
+ if always_bos_token:
1708
+ if not jinja_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
1709
+ jinja_template = "{{ bos_token }}" + jinja_template
1710
+ pass
1711
+
1712
+ # Get instruction and output parts for train_on_inputs = False
1713
+ input_part = input_part [:input_part .find("{INPUT}")]
1714
+ output_part = output_part[:output_part.find("{OUTPUT}")]
1715
+ return modelfile, jinja_template, input_part, output_part
1716
+ pass
1717
+
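+ # A usage sketch, assuming `tokenizer` is already loaded; the Alpaca-style template
+ # below follows the required form from `error_msg` above and repeats {INPUT}/{OUTPUT} twice:
+ #
+ # modelfile, jinja_template, input_part, output_part = construct_chat_template(
+ #     tokenizer = tokenizer,
+ #     chat_template = "### Input:\n{INPUT}\n\n### Response:\n{OUTPUT}\n"
+ #                     "### Input:\n{INPUT}\n\n### Response:\n{OUTPUT}\n",
+ # )
+ # tokenizer.chat_template = jinja_template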
1718
+
1719
+ def test_construct_chat_template():
1720
+ token = "hf_"
1721
+ from transformers import AutoTokenizer
1722
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token = token)
1723
+
1724
+ chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1725
+
1726
+ {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
1727
+
1728
+ {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1729
+
1730
+ {OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
1731
+
1732
+ {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1733
+
1734
+ {OUTPUT}<|eot_id|>"""
1735
+
1736
+ default_system_message = \
1737
+ "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
1738
+
1739
+ extra_eos_tokens = None
1740
+
1741
+ modelfile, jinja_template, _, _ = construct_chat_template(
1742
+ tokenizer = tokenizer,
1743
+ chat_template = chat_template,
1744
+ extra_eos_tokens = extra_eos_tokens,
1745
+ )
1746
+
1747
+ messages = [
1748
+ {"role": "system", "content": "You are an assistant"},
1749
+ {"role": "user", "content": "What is 2+2?"},
1750
+ {"role": "assistant", "content": "It's 4."},
1751
+ {"role": "user", "content": "Ok!"},
1752
+ {"role": "assistant", "content": "Anything else?"},
1753
+ {"role": "user", "content": "What's 2x2?"},
1754
+ ]
1755
+ correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1756
+
1757
+ tokenizer.chat_template = jinja_template
1758
+ new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1759
+ assert(correct_output == new_output)
1760
+ pass
1761
+ pass
1762
+
1763
+
1764
+ def apply_chat_template( \
1765
+
1766
+ dataset,
1767
+ tokenizer = None,
1768
+
1769
+ chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1770
+
1771
+ {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
1772
+
1773
+ {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1774
+
1775
+ {OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
1776
+
1777
+ {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1778
+
1779
+ {OUTPUT}<|eot_id|>""",
1780
+
1781
+ default_system_message = \
1782
+ "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
1783
+
1784
+ extra_eos_tokens = None,
1785
+
1786
+ ):
1787
+ """
1788
+ Creates an Ollama modelfile and an HF Jinja template from a custom
1789
+ template. You must provide 2x examples of an input & output.
1790
+ There is an optional system message as well.
1791
+
1792
+ You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
1793
+ """
1794
+ modelfile, jinja_template, input_part, output_part = construct_chat_template(
1795
+ tokenizer = tokenizer,
1796
+ chat_template = chat_template,
1797
+ default_system_message = default_system_message,
1798
+ extra_eos_tokens = extra_eos_tokens,
1799
+ )
1800
+ def formatting_prompts_func(examples):
1801
+ convos = examples["conversations"]
1802
+ texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
1803
+ return { "text" : texts, }
1804
+ pass
1805
+
1806
+ tokenizer.chat_template = jinja_template
1807
+ tokenizer._ollama_modelfile = modelfile
1808
+ tokenizer._unsloth_input_part = input_part
1809
+ tokenizer._unsloth_output_part = output_part
1810
+
1811
+ return dataset.map(formatting_prompts_func, batched = True,)
1812
+ pass
1813
+
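+ # A usage sketch, assuming `dataset` already has a "conversations" column in
+ # role/content format (e.g. after `standardize_sharegpt`) and `tokenizer` is loaded:
+ #
+ # dataset = apply_chat_template(
+ #     dataset,
+ #     tokenizer = tokenizer,
+ #     chat_template = chat_template,  # a {SYSTEM}/{INPUT}/{OUTPUT} style template
+ #     default_system_message = "You are a helpful assistant",
+ # )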
1814
+
1815
+ # From https://www.geeksforgeeks.org/longest-common-substring-array-strings/
1816
+ # Longest Common Substring in an Array of Strings
1817
+ def _longest_common_substring(arr):
1818
+ n = len(arr)
1819
+ s = arr[0]
1820
+ l = len(s)
1821
+ res = ""
1822
+ for i in range(l):
1823
+ for j in range(i + 1, l + 1):
1824
+ stem = s[i:j]
1825
+ k = 1
1826
+ for k in range(1, n):
1827
+ if stem not in arr[k]:
1828
+ break
1829
+ if (k + 1 == n and len(res) < len(stem)):
1830
+ res = stem
1831
+ return res
1832
+ pass
1833
+
1834
+
1835
+ def _find_common_token_ids(component, tokenizer):
1836
+ """
1837
+ \n### User:\n\n
1838
+ \n\n### User:\n\n
1839
+ etc
1840
+ we need to find the middle-most repeated part.
1841
+ Tokenizers can tokenize newlines or spaces as 1 token!
1842
+ """
1843
+ right_text = ""
1844
+ if component.endswith (" "): right_text = " "
1845
+ elif component.endswith("\n"): right_text = "\n"
1846
+ left_text = ""
1847
+ if component.startswith (" "): left_text = " "
1848
+ elif component.startswith("\n"): left_text = "\n"
1849
+ stripped = component.strip()
1850
+
1851
+ # Add current pieces and also newlines
1852
+ all_input_ids = []
1853
+ for left in range(3):
1854
+ for right in range(3):
1855
+ x = left*left_text + stripped + right*right_text
1856
+ x = tokenizer(x, add_special_tokens = False).input_ids
1857
+ all_input_ids.append(x)
1858
+
1859
+ x = left*"\n" + stripped + right*"\n"
1860
+ x = tokenizer(x, add_special_tokens = False).input_ids
1861
+ all_input_ids.append(x)
1862
+ pass
1863
+ pass
1864
+ substring = _longest_common_substring([str(x + [0]) for x in all_input_ids])
1865
+ substring = substring.split(", ")[:-1]
1866
+ substring = [int(x) for x in substring]
1867
+
1868
+ # Also get rest of tokenized string
1869
+ original = tokenizer(component, add_special_tokens = False).input_ids
1870
+ # Get optional left and right
1871
+ for j in range(len(original)):
1872
+ if original[j : j + len(substring)] == substring: break
1873
+ optional_left = original[:j]
1874
+ optional_right = original[j+len(substring):]
1875
+ return substring, optional_left, optional_right
1876
+ pass
1877
+
1878
+
1879
+ def train_on_responses_only(
1880
+ trainer,
1881
+ instruction_part = None,
1882
+ response_part = None,
1883
+ ):
1884
+ """
1885
+ Trains only on responses and not on the instruction by masking out
1886
+ the labels with -100 for the instruction part.
1887
+ """
1888
+ tokenizer = trainer.tokenizer
1889
+
1890
+ if not hasattr(tokenizer, "_unsloth_input_part") or \
1891
+ not hasattr(tokenizer, "_unsloth_output_part"):
1892
+
1893
+ if instruction_part is None or response_part is None:
1894
+ raise ValueError("Unsloth: instruction_part and response_part must be given!")
1895
+ pass
1896
+ elif (instruction_part is not None or response_part is not None) and \
1897
+ (hasattr(tokenizer, "_unsloth_input_part") or hasattr(tokenizer, "_unsloth_output_part")):
1898
+
1899
+ raise ValueError("Unsloth: Your tokenizer already has instruction and response parts set - do not give custom ones!")
1900
+ else:
1901
+ instruction_part = tokenizer._unsloth_input_part
1902
+ response_part = tokenizer._unsloth_output_part
1903
+ pass
1904
+
1905
+ # Get most common tokens since tokenizers can tokenize stuff differently!
1906
+ Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, tokenizer)
1907
+ A_must, A_left, A_right = _find_common_token_ids(response_part, tokenizer)
1908
+
1909
+ # Store some temporary stuff
1910
+ A_first = A_must[0]
1911
+ len_A_must = len(A_must)
1912
+ A_left_reversed = A_left[::-1]
1913
+ A_right_forward = A_right
1914
+
1915
+ Q_first = Q_must[0]
1916
+ len_Q_must = len(Q_must)
1917
+ Q_left_reversed = Q_left[::-1]
1918
+ Q_right_forward = Q_right
1919
+
1920
+ def _train_on_responses_only(examples):
1921
+ input_ids_ = examples["input_ids"]
1922
+ all_labels = []
1923
+
1924
+ for input_ids in input_ids_:
1925
+ n = len(input_ids)
1926
+ labels = [-100] * n
1927
+ n_minus_1 = n - 1
1928
+ j = 0
1929
+ while j < n:
1930
+ # Find <assistant>
1931
+ if (input_ids[j] == A_first) and \
1932
+ (input_ids[j : (k := j + len_A_must)] == A_must):
1933
+
1934
+ # Now backtrack to get previous optional tokens
1935
+ for optional_left in A_left_reversed:
1936
+ if j < 1: break
1937
+ if optional_left == input_ids[j-1]: j -= 1
1938
+ else: break
1939
+ pass
1940
+ # And forwards look as well
1941
+ for optional_right in A_right_forward:
1942
+ if k >= n_minus_1: break
1943
+ if optional_right == input_ids[k+1]: k += 1
1944
+ else: break
1945
+ pass
1946
+ # assistant_j = j
1947
+ assistant_k = k
1948
+
1949
+ j = assistant_k
1950
+ # Given <assistant>, now find next user
1951
+ while j < n:
1952
+ # Find <user>
1953
+ # Also accept last final item if assistant is the last turn
1954
+ if (j == n_minus_1) or \
1955
+ ((input_ids[j] == Q_first) and \
1956
+ (input_ids[j : (k := j + len_Q_must)] == Q_must)):
1957
+
1958
+ # Now backtrack to get previous optional tokens
1959
+ for optional_left in Q_left_reversed:
1960
+ if j < 1: break
1961
+ if optional_left == input_ids[j-1]: j -= 1
1962
+ else: break
1963
+ pass
1964
+ # And forwards look as well
1965
+ for optional_right in Q_right_forward:
1966
+ if k >= n_minus_1: break
1967
+ if optional_right == input_ids[k+1]: k += 1
1968
+ else: break
1969
+ pass
1970
+ user_j = j
1971
+ # Account for last item
1972
+ if user_j != n_minus_1:
1973
+ # user_k = k
1974
+ # j = user_k
1975
+ j = k
1976
+ else:
1977
+ user_j = n
1978
+ k = n
1979
+ pass
1980
+ # Now copy input_ids to labels
1981
+ labels[assistant_k : user_j] = input_ids[assistant_k : user_j]
1982
+ # print(assistant_j, assistant_k, user_j, user_k)
1983
+ break
1984
+ pass
1985
+ j += 1
1986
+ pass
1987
+ pass
1988
+ j += 1
1989
+ pass
1990
+ all_labels.append(labels)
1991
+ pass
1992
+ return { "labels" : all_labels }
1993
+ pass
1994
+
1995
+ if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
1996
+ trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True)
1997
+ if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
1998
+ trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True)
1999
+ return trainer
2000
+ pass
2001
+
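+ # A usage sketch for a Llama-3 style template; the marker strings are assumptions and
+ # must match the chat template actually applied to the dataset:
+ #
+ # trainer = train_on_responses_only(
+ #     trainer,
+ #     instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+ #     response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
+ # )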
2002
+
2003
+ def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
2004
+ class StoppingCriteriaSub(StoppingCriteria):
2005
+ __slots__ = "stop_token", "single_match", "length",
2006
+
2007
+ def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
2008
+ super().__init__()
2009
+ if stops == "eos_token":
2010
+ self.stop_token = torch.tensor(tokenizer.eos_token_id, device = device)
2011
+ self.length = 1
2012
+ else:
2013
+ self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
2014
+ self.stop_token = self.stop_token.input_ids.ravel()[1:].to(device)
2015
+ self.length = self.stop_token.shape[0]
2016
+ pass
2017
+ self.single_match = self.length == 1
2018
+ pass
2019
+
2020
+ def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
2021
+ input_ids = input_ids.ravel()
2022
+ last_token = input_ids[-1]
2023
+ if self.single_match and (last_token == self.stop_token): return True
2024
+
2025
+ if input_ids.shape[0] >= self.length and \
2026
+ (input_ids[-self.length:] == self.stop_token).all(): return True
2027
+ return False
2028
+ pass
2029
+ pass
2030
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
2031
+ return stopping_criteria
2032
+ pass
2033
+
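+ # A usage sketch (assumes a CUDA device, mirroring the defaults above):
+ #
+ # stopping_criteria = create_stopping_criteria(tokenizer, "<|im_end|>")
+ # outputs = model.generate(**inputs, stopping_criteria = stopping_criteria)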
2034
+
2035
+ def test_chat_templates():
2036
+ messages = [
2037
+ {"role": "system","content": " You are a friendly chatbot.",},
2038
+ {"role": "user", "content": "What is 2+2?"},
2039
+ {"role": "assistant", "content": "It's 4."},
2040
+ {"role": "user", "content": " But 2+2 is equal to 5. "},
2041
+ {"role": "assistant", "content": "No I'm sure its 4."},
2042
+ {"role": "user", "content": " No it's 100% 5! "},
2043
+ ]
2044
+
2045
+ # Zephyr
2046
+ from transformers import AutoTokenizer
2047
+ template = zephyr_template
2048
+ correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
2049
+ correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2050
+ correct_tokenizer.chat_template = template
2051
+ our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2052
+ assert(correct_prompt == our_prompt)
2053
+
2054
+ # Chatml
2055
+ template = chatml_template
2056
+ correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
2057
+ correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2058
+ correct_tokenizer.chat_template = template
2059
+ our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2060
+ assert(correct_prompt == our_prompt)
2061
+
2062
+ # Mistral
2063
+ template = mistral_template
2064
+ correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
2065
+ correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2066
+ correct_tokenizer.chat_template = template
2067
+ our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2068
+ assert(correct_prompt == our_prompt)
2069
+
2070
+ # Llama
2071
+ template = llama_template
2072
+ correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
2073
+ correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2074
+ correct_tokenizer.chat_template = template
2075
+ our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2076
+ assert(correct_prompt == our_prompt)
2077
+
2078
+ # Vicuna
2079
+ try:
2080
+ from fastchat.conversation import get_conv_template
2081
+ except:
2082
+ os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
2083
+ from fastchat.conversation import get_conv_template
2084
+ correct_prompt = get_conv_template("vicuna_v1.1")
2085
+ for j in range(len(messages)-1):
2086
+ correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
2087
+ correct_prompt.append_message(correct_prompt.roles[1], "")
2088
+ correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()
2089
+
2090
+ template = vicuna_template
2091
+ correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
2092
+ correct_tokenizer.chat_template = template
2093
+ our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2094
+ assert(correct_prompt == our_prompt)
2095
+
2096
+ try:
2097
+ from fastchat.conversation import get_conv_template
2098
+ except:
2099
+ os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
2100
+ from fastchat.conversation import get_conv_template
2101
+ correct_prompt = get_conv_template("zero_shot")
2102
+ for j in range(len(messages)-1):
2103
+ correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
2104
+ correct_prompt.append_message(correct_prompt.roles[1], "")
2105
+ correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()
2106
+
2107
+ template = vicuna_old_template
2108
+ correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
2109
+ correct_tokenizer.chat_template = template
2110
+ our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2111
+ # We add </s> ourselves
2112
+ assert(correct_prompt == our_prompt.replace("</s>", ""))
2113
+
2114
+ # Gemma
2115
+ correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
2116
+ correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2117
+ correct_tokenizer.chat_template = gemma_template
2118
+ our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2119
+ assert(our_prompt == correct_prompt)
2120
+
2121
+ # Llama-3
2122
+ template = llama3_template
2123
+ correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
2124
+ correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2125
+ correct_tokenizer.chat_template = template
2126
+ our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2127
+ assert(correct_prompt == our_prompt)
2128
+
2129
+ # Phi-3
2130
+ template = phi3_template
2131
+ correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
2132
+ correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2133
+ correct_tokenizer.chat_template = template
2134
+ our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2135
+ assert(correct_prompt == our_prompt)
2136
+ pass
2137
+
2138
+
2139
+ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
2140
+ """
2141
+ Carefully checks the output of GGUF's tokenization and HF.
2142
+ Can catch all tokenization bugs.
2143
+ """
2144
+ import subprocess
2145
+ import re
2146
+ messages = [
2147
+ {"role": "user", "content": "What is 2+2?"},
2148
+ {"role": "assistant", "content": "It's 4."},
2149
+ {"role": "user", "content": " But 2+2 is equal to 5. "},
2150
+ {"role": "assistant", "content": "No I'm sure its 4."},
2151
+ {"role": "user", "content": " No it's 100% 5! "},
2152
+ ]
2153
+
2154
+ prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
2155
+
2156
+ ### Instruction:
2157
+ {}
2158
+
2159
+ ### Input:
2160
+ {}
2161
+
2162
+ ### Response:
2163
+ {}""".format(
2164
+ "Describe the city given eloquently.", # instruction
2165
+ "The lost city of Atlantis.", # input
2166
+ "", # output - leave this blank for generation!
2167
+ )
2168
+ prompts = [ prompt, ]
2169
+
2170
+ if tokenizer.chat_template is not None:
2171
+ prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2172
+ prompt = prompt.replace("'", "") # Subprocess does not like ''
2173
+ prompt = remove_special_tokens(tokenizer, prompt)
2174
+ prompts.append(prompt)
2175
+ pass
2176
+
2177
+ for prompt in prompts:
2178
+ command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
2179
+ f"--check-tensors -p '{prompt}'"
2180
+
2181
+ datas = []
2182
+ with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
2183
+ for line in sp.stdout:
2184
+ datas.append(line.decode("utf-8", errors = "replace"))
2185
+ pass
2186
+ gguf_tokens = "".join(datas)
2187
+
2188
+ # Now extract GGUF tokenization attempt
2189
+ gguf_tokenized = re.findall("([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
2190
+ gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
2191
+ input_ids = tokenizer(prompt).input_ids
2192
+
2193
+ tokens = tokenizer.batch_decode(input_ids)
2194
+ hf_tokenized = list(zip(input_ids, tokens))
2195
+
2196
+ # Compare to Huggingface
2197
+ for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
2198
+ if (hf_token[0] != gguf_token[0]):
2199
+ print("Failed GGUF != HF at", j)
2200
+ print("HF =", hf_token)
2201
+ print("GGUF =", gguf_token)
2202
+ print(hf_tokenized)
2203
+ print()
2204
+ print(gguf_tokenized)
2205
+ print()
2206
+ raise RuntimeError("Failed comparing GGUF to HF.")
2207
+ pass
2208
+ pass
2209
+ return True
2210
+ pass
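Editor's note: the helper above shells out to llama.cpp's llama-cli, scrapes its verbose token dump, and compares it id-by-id against the Hugging Face tokenizer. A minimal sketch of that comparison step in isolation (the tokenizer name is illustrative and the GGUF side is stubbed with the HF stream):

from transformers import AutoTokenizer

def compare_token_streams(hf_stream, gguf_stream):
    # Each stream is a list of (token_id, token_text) pairs.
    for j, (hf_tok, gguf_tok) in enumerate(zip(hf_stream, gguf_stream)):
        if hf_tok[0] != gguf_tok[0]:
            raise RuntimeError(f"Token mismatch at {j}: HF={hf_tok} GGUF={gguf_tok}")
    return True

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
ids       = tokenizer("What is 2+2?").input_ids
hf_stream = list(zip(ids, tokenizer.batch_decode(ids)))
compare_token_streams(hf_stream, hf_stream)  # stand-in for the parsed GGUF stream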
unsloth-main/unsloth-main/unsloth/kernels/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .cross_entropy_loss import (
16
+ fast_cross_entropy_loss,
17
+ patch_llama_for_causal_lm,
18
+ unpatch_llama_for_causal_lm,
19
+ )
20
+ from .rms_layernorm import (
21
+ fast_rms_layernorm,
22
+ patch_rms_layernorm,
23
+ unpatch_rms_layernorm,
24
+ )
25
+ from .layernorm import (
26
+ fast_layernorm,
27
+ patch_layernorm,
28
+ unpatch_layernorm,
29
+ )
30
+ from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
31
+ from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
32
+ from .geglu import (
33
+ geglu_exact_forward_kernel,
34
+ geglu_exact_backward_kernel,
35
+ geglu_approx_forward_kernel,
36
+ geglu_approx_backward_kernel,
37
+ )
38
+ from .fast_lora import (
39
+ get_lora_parameters,
40
+ get_lora_parameters_bias,
41
+ apply_lora_mlp_swiglu,
42
+ apply_lora_mlp_geglu_exact,
43
+ apply_lora_mlp_geglu_approx,
44
+ apply_lora_qkv,
45
+ apply_lora_o,
46
+ )
47
+ from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
48
+
49
+ from .flex_attention import (
50
+ HAS_FLEX_ATTENTION,
51
+ slow_attention_softcapping,
52
+ slow_inference_attention_softcapping,
53
+ create_flex_attention_causal_mask,
54
+ create_flex_attention_sliding_window_mask,
55
+ )
56
+
57
+ try:
58
+ print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
59
+ except:
60
+ print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
61
+ pass
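Editor's note: the exports above are what the model patching code imports. A short usage sketch of the fused cross entropy entry point, assuming a CUDA device and the package layout shown (shapes are illustrative):

import torch
from unsloth.kernels import fast_cross_entropy_loss

logits = torch.randn(2, 8, 32000, device = "cuda", dtype = torch.float16)
labels = torch.randint(0, 32000, (2, 8), device = "cuda")
labels[:, -1] = -100                      # -100 marks positions to ignore
loss = fast_cross_entropy_loss(logits, labels)
print(loss)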
unsloth-main/unsloth-main/unsloth/kernels/cross_entropy_loss.py ADDED
@@ -0,0 +1,461 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh
19
+ from transformers.models.llama.modeling_llama import logger
20
+
21
+
22
+ @triton.heuristics({
23
+ "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
24
+ "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
25
+ })
26
+ @triton.jit
27
+ def _cross_entropy_forward(
28
+ logits_ptr, logits_row_stride,
29
+ loss_ptr,
30
+ logsumexp_ptr,
31
+ labels_ptr,
32
+ VOCAB_SIZE : tl.constexpr,
33
+ BLOCK_SIZE : tl.constexpr,
34
+ DO_SOFTCAPPING : tl.constexpr,
35
+ SOFTCAP : tl.constexpr,
36
+ DO_LOGIT_SCALING: tl.constexpr,
37
+ LOGIT_SCALE : tl.constexpr,
38
+ ):
39
+ """
40
+ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
41
+ Pi = exp(xi) / sum(exp(xi))
42
+ CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
43
+ = -y [ x - log[sum(exp(x))] ]
44
+ = y * (log[sum(exp(x))] - x)
45
+ If y == 0: CE_i = 0
46
+ If y == 1: CE_i = logsumexp - x
47
+
48
+ logsumexp is also stable
49
+ Take y = log[sum(exp(x))]
50
+ exp(y) = sum(exp(x))
51
+ exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
52
+ exp(y) = exp(c)*sum(exp(x - c))
53
+ y = log(exp(c)*sum(exp(x - c)))
54
+ y = c + log[sum(exp(x - c))]
55
+ This means we can set c = max(x) to make sure
56
+ exp(x - c) always is exp(x - max(x)).
57
+ This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
58
+ """
59
+ row_idx = tl.program_id(0)
60
+ logits_ptr += row_idx * logits_row_stride.to(tl.int64)
61
+ loss_ptr += row_idx
62
+ logsumexp_ptr += row_idx
63
+ labels_ptr += row_idx
64
+
65
+ col_offsets = tl.arange(0, BLOCK_SIZE)
66
+ mask = col_offsets < VOCAB_SIZE
67
+
68
+ label_idx = tl.load(labels_ptr).to(tl.int32)
69
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
70
+
71
+ # Do logit scaling for Cohere: t * x
72
+ if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
73
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
74
+ if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
75
+
76
+ logits = logits.to(tl.float32)
77
+ c = tl.max(logits, 0)
78
+ logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
79
+
80
+ if label_idx != -100:
81
+ x = tl.load(logits_ptr + label_idx)
82
+ # Do logit scaling for Cohere: t * x
83
+ if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
84
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
85
+ if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
86
+ loss = logsumexp - x.to(tl.float32)
87
+ else:
88
+ loss = 0.0
89
+ tl.store(logsumexp_ptr, logsumexp)
90
+ tl.store(loss_ptr, loss)
91
+ pass
92
+
93
+
94
+ @triton.heuristics({
95
+ "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
96
+ "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
97
+ })
98
+ @triton.jit
99
+ def _chunked_cross_entropy_forward(
100
+ logits_ptr, logits_row_stride,
101
+ loss_ptr,
102
+ logsumexp_ptr,
103
+ labels_ptr,
104
+ VOCAB_SIZE : tl.constexpr,
105
+ N_CHUNKS : tl.constexpr,
106
+ BLOCK_SIZE : tl.constexpr,
107
+ DO_SOFTCAPPING : tl.constexpr,
108
+ SOFTCAP : tl.constexpr,
109
+ DO_LOGIT_SCALING: tl.constexpr,
110
+ LOGIT_SCALE : tl.constexpr,
111
+ ):
112
+ """
113
+ 256K vocab divided in 4 chunks
114
+
115
+ |-65536-| |-65536-| |-65536-| |-65536-|
116
+ |-------| |-------| |-------| |-------|
117
+ |-------| |-------| |-------| |-------|
118
+
119
+ If y == 0: CE_i = 0
120
+ If y == 1: CE_i = logsumexp - x
121
+
122
+ Notice we can do logsumexp for each chunk and then
123
+ logsumexp[chunk_sum(logsumexp)] == logsumexp
124
+
125
+ chunk_sum = log[chunk_sum(logsumexp)]
126
+ = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
127
+ = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
128
+ = log[sum(exp(a)) + ... + sum(exp(z))]
129
+ = logsumexp(x)
130
+
131
+ This means we can perform a logsumexp for each chunk, then do a
132
+ final logsumexp reduction!
133
+
134
+ I.e. do: logsumexp(chunked_logsumexp) - x
135
+ """
136
+ row_idx = tl.program_id(0)
137
+ chunk_idx = tl.program_id(1)
138
+ logits_ptr += row_idx * logits_row_stride.to(tl.int64)
139
+ loss_ptr += row_idx
140
+ logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
141
+ labels_ptr += row_idx
142
+
143
+ col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
144
+ mask = col_offsets < VOCAB_SIZE
145
+
146
+ label_idx = tl.load(labels_ptr).to(tl.int32)
147
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
148
+
149
+ # Do logit scaling for Cohere: t * x
150
+ if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
151
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
152
+ if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
153
+
154
+ logits = logits.to(tl.float32)
155
+ c = tl.max(logits, 0)
156
+ logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
157
+
158
+ if chunk_idx == 0:
159
+ # logsumexp(chunked_logsumexp) - x
160
+ # Do the -x separately
161
+ if label_idx != -100:
162
+ x = tl.load(logits_ptr + label_idx).to(tl.float32)
163
+ # Do logit scaling for Cohere: t * x
164
+ if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
165
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
166
+ if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
167
+ loss = -1.0 * x.to(tl.float32)
168
+ else:
169
+ loss = 0.0
170
+ tl.store(loss_ptr, loss)
171
+ pass
172
+ tl.store(logsumexp_ptr, logsumexp)
173
+ pass
174
+
175
+
176
+ @triton.heuristics({
177
+ "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
178
+ "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
179
+ })
180
+ @triton.jit
181
+ def _cross_entropy_backward(
182
+ logits_ptr, logits_row_stride,
183
+ dloss_ptr, dloss_row_stride,
184
+ logsumexp_ptr,
185
+ labels_ptr,
186
+ VOCAB_SIZE : tl.constexpr,
187
+ BLOCK_SIZE : tl.constexpr,
188
+ DO_SOFTCAPPING : tl.constexpr,
189
+ SOFTCAP : tl.constexpr,
190
+ DO_LOGIT_SCALING: tl.constexpr,
191
+ LOGIT_SCALE : tl.constexpr,
192
+ ):
193
+ """
194
+ CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
195
+ dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
196
+
197
+ From https://en.wikipedia.org/wiki/LogSumExp
198
+ d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
199
+
200
+ dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
201
+ dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
202
+ dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
203
+
204
+ If y == 0: dC/dx = 0
205
+ If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
206
+ If y == 1 and x != label: dC/dx = exp[x - logsumexp]
207
+ """
208
+ row_idx = tl.program_id(0)
209
+ block_idx = tl.program_id(1)
210
+
211
+ logits_ptr += row_idx * logits_row_stride.to(tl.int64)
212
+ dloss_ptr += row_idx * dloss_row_stride
213
+ col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
214
+ mask = col_offsets < VOCAB_SIZE
215
+ label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
216
+
217
+ if label_idx != -100:
218
+ dloss = tl.load(dloss_ptr)
219
+ else:
220
+ dloss = 0.0
221
+
222
+ x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
223
+
224
+ # Do logit scaling for Cohere
225
+ if DO_LOGIT_SCALING:
226
+ # d/dx [s * x] = s
227
+ x = x * LOGIT_SCALE
228
+ pass
229
+
230
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
231
+ if DO_SOFTCAPPING:
232
+ # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
233
+ partial = triton_tanh(x / SOFTCAP)
234
+ x = SOFTCAP * partial
235
+ pass
236
+
237
+ logsumexp = tl.load(logsumexp_ptr + row_idx)
238
+ y = tl.exp(x.to(tl.float32) - logsumexp)
239
+ y = tl.where(
240
+ col_offsets == label_idx,
241
+ y - 1.0, # exp(x - logsumexp) - 1
242
+ y, # exp(x - logsumexp)
243
+ )
244
+
245
+ if DO_LOGIT_SCALING:
246
+ # d/dx [s * x] = s
247
+ y = y * LOGIT_SCALE
248
+ pass
249
+
250
+ if DO_SOFTCAPPING:
251
+ # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
252
+ y = y * (1.0 - partial*partial)
253
+ pass
254
+
255
+ # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
256
+ tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
257
+ pass
258
+
259
+
260
+ MAX_FUSED_SIZE = 65536 # 2**16
261
+
262
+ class Fast_CrossEntropyLoss(torch.autograd.Function):
263
+ @staticmethod
264
+ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0):
265
+ n_rows, vocab_size = logits.shape
266
+
267
+ div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
268
+ n_chunks = div + (mod != 0)
269
+ losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
270
+
271
+ DO_SOFTCAPPING = (logit_softcapping != 0)
272
+ DO_LOGIT_SCALING = (logit_scaling != 0)
273
+
274
+ if n_chunks == 1:
275
+ # For small vocabs <= 65536 like Llama, Mistral
276
+ BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
277
+ logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
278
+
279
+ _cross_entropy_forward[(n_rows,)](
280
+ logits, logits.stride(0),
281
+ losses,
282
+ logsumexp,
283
+ labels,
284
+ VOCAB_SIZE = vocab_size,
285
+ BLOCK_SIZE = BLOCK_SIZE,
286
+ DO_SOFTCAPPING = DO_SOFTCAPPING,
287
+ SOFTCAP = logit_softcapping,
288
+ DO_LOGIT_SCALING = DO_LOGIT_SCALING,
289
+ LOGIT_SCALE = logit_scaling,
290
+ num_warps = num_warps,
291
+ )
292
+ else:
293
+ # For large vocabs > 65536 like Gemma 256K
294
+ logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
295
+
296
+ _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
297
+ logits, logits.stride(0),
298
+ losses,
299
+ logsumexp,
300
+ labels,
301
+ VOCAB_SIZE = vocab_size,
302
+ N_CHUNKS = n_chunks,
303
+ BLOCK_SIZE = MAX_FUSED_SIZE,
304
+ DO_SOFTCAPPING = DO_SOFTCAPPING,
305
+ SOFTCAP = logit_softcapping,
306
+ DO_LOGIT_SCALING = DO_LOGIT_SCALING,
307
+ LOGIT_SCALE = logit_scaling,
308
+ num_warps = 32,
309
+ )
310
+ # logsumexp(chunked_logsumexp) - x
311
+ # Do the -x separately
312
+ logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
313
+ losses += logsumexp
314
+ losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
315
+ pass
316
+
317
+ ctx.save_for_backward(logits, logsumexp, labels)
318
+ ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
319
+ ctx.logit_softcapping = logit_softcapping
320
+ ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
321
+ ctx.logit_scaling = logit_scaling
322
+ return losses
323
+ pass
324
+
325
+ @staticmethod
326
+ def backward(ctx, dlosses):
327
+ logits, logsumexp, labels = ctx.saved_tensors
328
+ n_rows, vocab_size = logits.shape
329
+
330
+ BLOCK_SIZE = 4096
331
+ div, mod = divmod(vocab_size, BLOCK_SIZE)
332
+ n_blocks = div + (mod != 0)
333
+
334
+ _cross_entropy_backward[(n_rows, n_blocks,)](
335
+ logits, logits.stride(0),
336
+ dlosses, dlosses.stride(0),
337
+ logsumexp,
338
+ labels,
339
+ VOCAB_SIZE = vocab_size,
340
+ BLOCK_SIZE = BLOCK_SIZE,
341
+ DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
342
+ SOFTCAP = ctx.logit_softcapping,
343
+ DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
344
+ LOGIT_SCALE = ctx.logit_scaling,
345
+ num_warps = 8,
346
+ )
347
+ return logits, None, None, None,
348
+ pass
349
+ pass
350
+
351
+
352
+ @torch._disable_dynamo
353
+ def fast_cross_entropy_loss(
354
+ logits,
355
+ labels,
356
+ logit_softcapping = 0,
357
+ logit_scaling = 0,
358
+ ):
359
+ """
360
+ Arguments:
361
+ logits: (batch, seq_len, vocab_size)
362
+ labels: (batch, seq_len,)
363
+ Returns:
364
+ losses: float
365
+ """
366
+ batch, seq_len, d = logits.shape
367
+ assert(labels.shape == (batch, seq_len))
368
+
369
+ loss = Fast_CrossEntropyLoss.apply(
370
+ logits.view(batch*seq_len, d),
371
+ labels.view(-1),
372
+ logit_softcapping,
373
+ logit_scaling,
374
+ )
375
+ n_items = torch.count_nonzero(labels != -100)
376
+ return loss.sum() / n_items
377
+ pass
378
+
379
+
380
+ from transformers.models.llama.modeling_llama import (
381
+ LlamaForCausalLM,
382
+ CausalLMOutputWithPast,
383
+ Optional,
384
+ Union,
385
+ Cache,
386
+ List,
387
+ Tuple,
388
+ )
389
+ import inspect, re
390
+ function = inspect.getsource(LlamaForCausalLM.forward)
391
+ function = function.split("\n")
392
+ i = re.match(r"[ ]{1,}", function[0]).span(0)[1]
393
+ function = [x[i:] for x in function]
394
+ function = "\n".join(function)
395
+ function = function[function.find("def forward"):]
396
+ replacement = """ loss = None
397
+ logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
398
+ logit_scaling = getattr(self.config, "logit_scale", 0)
399
+ if labels is not None:
400
+ shift_logits = logits
401
+ if not hasattr(self, "extra_ignored_labels"):
402
+ # Fixes https://github.com/unslothai/unsloth/issues/10
403
+ self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
404
+ pass
405
+
406
+ shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
407
+ loss = fast_cross_entropy_loss(
408
+ logits = shift_logits,
409
+ labels = shift_labels,
410
+ logit_softcapping = logit_softcapping,
411
+ logit_scaling = logit_scaling,
412
+ )
413
+ else:
414
+ if logit_scaling != 0:
415
+ if logits.requires_grad:
416
+ logits = logit_scaling * logits
417
+ else:
418
+ logits *= logit_scaling
419
+ pass
420
+ pass
421
+ if logit_softcapping != 0:
422
+ if logits.requires_grad:
423
+ logits = (1.0 / logit_softcapping) * logits
424
+ logits = torch.tanh(logits)
425
+ logits = logit_softcapping * logits
426
+ else:
427
+ logits *= (1.0 / logit_softcapping)
428
+ torch.tanh(logits, out = logits)
429
+ logits *= logit_softcapping
430
+ pass
431
+ pass
432
+ pass
433
+ """
434
+ function = \
435
+ function[:function.find(" loss = None")] + \
436
+ replacement + \
437
+ function[ function.find(" if not return_dict"):]
438
+ function = function.replace("logits = logits.float()", "\n")
439
+ # Missed spaces
440
+ function = function.split("\n")
441
+ # Not the first one though!
442
+ function = [function[0]] + [" "*4 + x for x in function[1:]]
443
+ function = "\n".join(function)
444
+ function = f"class Unsloth_LlamaForCausalLM(LlamaForCausalLM):\n"\
445
+ f" {function}\n"
446
+ exec(function, globals())
447
+ del function, replacement, inspect, re
448
+
449
+
450
+ def patch_llama_for_causal_lm():
451
+ import transformers.models.llama.modeling_llama
452
+ transformers.models.llama.modeling_llama.LlamaForCausalLM = Unsloth_LlamaForCausalLM
453
+ return
454
+ pass
455
+
456
+
457
+ def unpatch_llama_for_causal_lm():
458
+ import transformers.models.llama.modeling_llama
459
+ transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
460
+ return
461
+ pass
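Editor's note: a plain PyTorch reference for the loss the Triton kernels above compute, useful for spot-checking them. This is a sketch, not part of the upload, and the mini test at the bottom uses made-up shapes:

import torch

def reference_cross_entropy(logits, labels, softcap = 0.0, scale = 0.0):
    # Same math as the kernels: optional logit scaling, optional softcapping,
    # then loss_i = logsumexp(x_i) - x_i[label], averaged over labels != -100.
    logits = logits.float()
    if scale   != 0.0: logits = scale * logits
    if softcap != 0.0: logits = softcap * torch.tanh(logits / softcap)
    logsumexp = torch.logsumexp(logits, dim = -1)
    mask   = labels != -100
    picked = logits.gather(-1, labels.clamp(min = 0).unsqueeze(-1)).squeeze(-1)
    return ((logsumexp - picked) * mask).sum() / mask.sum()

logits = torch.randn(4, 7, 128)
labels = torch.randint(0, 128, (4, 7)); labels[:, 0] = -100
print(reference_cross_entropy(logits, labels))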
unsloth-main/unsloth-main/unsloth/kernels/fast_lora.py ADDED
@@ -0,0 +1,412 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from .utils import (
17
+ fast_dequantize,
18
+ QUANT_STATE,
19
+ get_lora_parameters,
20
+ get_lora_parameters_bias,
21
+ matmul_lora,
22
+ torch_amp_custom_fwd,
23
+ torch_amp_custom_bwd,
24
+ )
25
+
26
+
27
+ class LoRA_MLP(torch.autograd.Function):
28
+ """
29
+ ### LoRA weights
30
+ G = G + Ag @ Bg
31
+ U = U + Au @ Bu
32
+ W = W + Aw @ Bw
33
+
34
+ ### SwiGLU(X)
35
+ e = X @ G
36
+ f = e * sigmoid(e)
37
+ g = X @ U
38
+ h = f * g
39
+ i = h @ W
40
+
41
+ ### Backpropagation chain rule
42
+ See our blog post for more details
43
+
44
+ df = sigmoid(e) * (1 - f) + f
45
+ dC/dW = h.T @ dY
46
+ dC/dU = X.T @ (D @ W.T * f)
47
+ dC/dG = X.T @ (D @ W.T * df * g)
48
+
49
+ ### Down projection LoRA weights
50
+ dC/dAw = dC/dW @ B.T
51
+ dC/dBw = A.T @ dC/dW
52
+ dC/dAw = h.T @ dY @ B.T
53
+ dC/dBw = A.T @ h.T @ dY
54
+
55
+ ### Up projection LoRA weights
56
+ dC/dAu = X.T @ (D @ W.T * f) @ B.T
57
+ dC/dBu = A.T @ X.T @ (D @ W.T * f)
58
+
59
+ ### Gate projection LoRA weights
60
+ dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
61
+ dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
62
+
63
+ Don't forget to see our blog post for more details!
64
+ """
65
+ @staticmethod
66
+ @torch_amp_custom_fwd
67
+ def forward(ctx, X : torch.Tensor,
68
+ gateW, gateW_quant, gateA, gateB, gateS,
69
+ upW, upW_quant, upA, upB, upS,
70
+ downW, downW_quant, downA, downB, downS,
71
+ _forward_function, _backward_function,
72
+ inplace = True,):
73
+ dtype = X.dtype
74
+
75
+ e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
76
+ g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
77
+ h = _forward_function(e, g)
78
+ i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
79
+
80
+ ctx.custom_saved_tensors = (
81
+ gateW, gateW_quant, gateS,
82
+ upW, upW_quant, upS,
83
+ downW, downW_quant, downS,
84
+ _backward_function,
85
+ )
86
+ ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
87
+ X, e, g)
88
+ ctx.inplace = inplace
89
+ return i
90
+ pass
91
+
92
+
93
+ @staticmethod
94
+ @torch_amp_custom_bwd
95
+ def backward(ctx, dY : torch.Tensor):
96
+ gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
97
+ _backward_function = ctx.custom_saved_tensors
98
+ gateA, gateB, upA, upB, downA, downB, \
99
+ X, e, g = ctx.saved_tensors
100
+
101
+ gateA, gateB, upA, upB, downA, downB = \
102
+ gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
103
+
104
+ batch, seq_len, hd = X.shape
105
+ dY = dY.view(-1, dY.shape[-1])
106
+ X = X .view(-1, X .shape[-1])
107
+ e = e .view(-1, e .shape[-1])
108
+ g = g .view(-1, g .shape[-1])
109
+ dtype = X.dtype
110
+
111
+ DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
112
+ DW, e, g = _backward_function(DW, e, g)
113
+ h, df, de = DW, e, g
114
+
115
+ # Down projection LoRA weights
116
+ d_downA = h.t() @ (dY @ downB.t())
117
+ d_downB = (downA.t() @ h.t()) @ dY
118
+ d_downA *= downS
119
+ d_downB *= downS
120
+
121
+ # Up projection LoRA weights
122
+ d_upA = X.t() @ (df @ upB.t())
123
+ d_upB = (upA.t() @ X.t()) @ df
124
+ d_upA *= upS
125
+ d_upB *= upS
126
+
127
+ # Gate projection LoRA weights
128
+ d_gateA = X.t() @ (de @ gateB.t())
129
+ d_gateB = (gateA.t() @ X.t()) @ de
130
+ d_gateA *= gateS
131
+ d_gateB *= gateS
132
+
133
+ # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
134
+ # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
135
+ upW = fast_dequantize(upW.t(), upW_quant)
136
+ dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
137
+ del upW
138
+ dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
139
+
140
+ gateW = fast_dequantize(gateW.t(), gateW_quant)
141
+ dX += de @ gateW.t()
142
+ del gateW
143
+ dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
144
+
145
+ # gateW, gateW_quant, gateA, gateB, gateS,
146
+ # upW, upW_quant, upA, upB, upS,
147
+ # downW, downW_quant, downA, downB, downS,
148
+ return dX.view(batch, seq_len, hd), \
149
+ None, None, d_gateA.t(), d_gateB.t(), None, \
150
+ None, None, d_upA.t(), d_upB.t(), None, \
151
+ None, None, d_downA.t(), d_downB.t(), None, \
152
+ None, None, None, # _backward and _forward and inplace
153
+ pass
154
+ pass
155
+
156
+
157
+ from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
158
+ def apply_lora_mlp_swiglu(self, X, inplace = True):
159
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
160
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
161
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
162
+ out = LoRA_MLP.apply(X,
163
+ gateW, gateW_quant, gateA, gateB, gateS,
164
+ upW, upW_quant, upA, upB, upS,
165
+ downW, downW_quant, downA, downB, downS,
166
+ swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
167
+ inplace,)
168
+ return out
169
+ pass
170
+
171
+
172
+ from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
173
+ def apply_lora_mlp_geglu_exact(self, X, inplace = True):
174
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
175
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
176
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
177
+ out = LoRA_MLP.apply(X,
178
+ gateW, gateW_quant, gateA, gateB, gateS,
179
+ upW, upW_quant, upA, upB, upS,
180
+ downW, downW_quant, downA, downB, downS,
181
+ geglu_exact_forward_kernel, geglu_exact_backward_kernel,
182
+ inplace,)
183
+ return out
184
+ pass
185
+
186
+
187
+ from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
188
+ def apply_lora_mlp_geglu_approx(self, X):
189
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
190
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
191
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
192
+ out = LoRA_MLP.apply(X,
193
+ gateW, gateW_quant, gateA, gateB, gateS,
194
+ upW, upW_quant, upA, upB, upS,
195
+ downW, downW_quant, downA, downB, downS,
196
+ geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
197
+ return out
198
+ pass
199
+
200
+
201
+ class LoRA_QKV(torch.autograd.Function):
202
+ """
203
+ ### LoRA weights
204
+ Wq = Wq + Aq @ Bq
205
+ Wk = Wk + Ak @ Bk
206
+ Wv = Wv + Av @ Bv
207
+ Q = X @ Wq = X @ Wq + X @ Aq @ Bq
208
+ K = X @ Wk = X @ Wk + X @ Ak @ Bk
209
+ V = X @ Wv = X @ Wv + X @ Av @ Bv
210
+
211
+ ### Backpropagation chain rule
212
+ See our blogpost for more details.
213
+
214
+ dC/dWq = X.T @ D(Wq)
215
+ dC/dWk = X.T @ D(Wk)
216
+ dC/dWv = X.T @ D(Wv)
217
+ We then sum them all to find dC/dX
218
+
219
+ ### Q projection LoRA weights
220
+ dC/dAq = X.T @ D(Wq) @ B.T
221
+ dC/dBq = A.T @ X.T @ D(Wq)
222
+
223
+ ### K projection LoRA weights
224
+ dC/dAk = X.T @ D(Wk) @ B.T
225
+ dC/dBk = A.T @ X.T @ D(Wk)
226
+
227
+ ### V projection LoRA weights
228
+ dC/dAv = X.T @ D(Wv) @ B.T
229
+ dC/dBv = A.T @ X.T @ D(Wv)
230
+ """
231
+ @staticmethod
232
+ @torch_amp_custom_fwd
233
+ def forward(ctx, X : torch.Tensor,
234
+ QW, QW_quant, QA, QB, QS,
235
+ KW, KW_quant, KA, KB, KS,
236
+ VW, VW_quant, VA, VB, VS,
237
+ inplace = True):
238
+ dtype = X.dtype
239
+
240
+ Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
241
+ K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
242
+ V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
243
+
244
+ ctx.custom_saved_tensors = (
245
+ QW, QW_quant, QS,
246
+ KW, KW_quant, KS,
247
+ VW, VW_quant, VS,
248
+ )
249
+ ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
250
+ ctx.inplace = inplace
251
+ return Q, K, V
252
+ pass
253
+
254
+ @staticmethod
255
+ @torch_amp_custom_bwd
256
+ def backward(ctx, dQ, dK, dV):
257
+ QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
258
+ ctx.custom_saved_tensors
259
+ X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
260
+
261
+ QA, QB, KA, KB, VA, VB = \
262
+ QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
263
+
264
+ batch, seq_len, hd = X.shape
265
+ dQ = dQ.view(-1, dQ.shape[-1])
266
+ dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
267
+ dV = dV.view(-1, dV.shape[-1])
268
+ X = X .view(-1, X .shape[-1])
269
+ dtype = X.dtype
270
+
271
+ ### Weight projection LoRA weights
272
+ # See our blogpost for more details.
273
+
274
+ # Q Projection
275
+ d_QA = X.t() @ (dQ @ QB.t())
276
+ d_QB = (QA.t() @ X.t()) @ dQ
277
+ d_QA *= QS
278
+ d_QB *= QS
279
+
280
+ # K Projection
281
+ d_KA = X.t() @ (dK @ KB.t())
282
+ d_KB = (KA.t() @ X.t()) @ dK
283
+ d_KA *= KS
284
+ d_KB *= KS
285
+
286
+ # V Projection
287
+ d_VA = X.t() @ (dV @ VB.t())
288
+ d_VB = (VA.t() @ X.t()) @ dV
289
+ d_VA *= VS
290
+ d_VB *= VS
291
+
292
+ # Combine derivatives to find dX
293
+ # dQ
294
+ QW = fast_dequantize(QW.t(), QW_quant)
295
+ dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
296
+ del QW
297
+ dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
298
+
299
+ # dK
300
+ KW = fast_dequantize(KW.t(), KW_quant)
301
+ dX += dK @ KW.t()
302
+ del KW
303
+ dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
304
+
305
+ # dV
306
+ VW = fast_dequantize(VW.t(), VW_quant)
307
+ dX += dV @ VW.t()
308
+ del VW
309
+ dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
310
+
311
+ # QW, QW_quant, QA, QB, QS,
312
+ # KW, KW_quant, KA, KB, KS,
313
+ # VW, VW_quant, VA, VB, VS,
314
+ return dX.view(batch, seq_len, hd), \
315
+ None, None, d_QA.t(), d_QB.t(), None, \
316
+ None, None, d_KA.t(), d_KB.t(), None, \
317
+ None, None, d_VA.t(), d_VB.t(), None, \
318
+ None,
319
+ pass
320
+ pass
321
+
322
+
323
+ def apply_lora_qkv(self, X, inplace = True):
324
+ QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
325
+ KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
326
+ VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
327
+ Q, K, V = LoRA_QKV.apply(X,
328
+ QW, QW_quant, QA, QB, QS,
329
+ KW, KW_quant, KA, KB, KS,
330
+ VW, VW_quant, VA, VB, VS,
331
+ inplace,
332
+ )
333
+ return Q, K, V
334
+ pass
335
+
336
+
337
+ class LoRA_W(torch.autograd.Function):
338
+ """
339
+ ### LoRA weights
340
+ Wq = Wq + Aq @ Bq
341
+ Wk = Wk + Ak @ Bk
342
+ Wv = Wv + Av @ Bv
343
+ Q = X @ Wq = X @ Wq + X @ Aq @ Bq
344
+ K = X @ Wk = X @ Wk + X @ Ak @ Bk
345
+ V = X @ Wv = X @ Wv + X @ Av @ Bv
346
+
347
+ ### Backpropagation chain rule
348
+ dC/dWq = X.T @ D(Wq)
349
+ dC/dWk = X.T @ D(Wk)
350
+ dC/dWv = X.T @ D(Wv)
351
+
352
+ ### Q projection LoRA weights
353
+ dC/dAq = X.T @ D(Wq) @ B.T
354
+ dC/dBq = A.T @ X.T @ D(Wq)
355
+
356
+ ### K projection LoRA weights
357
+ dC/dAk = X.T @ D(Wk) @ B.T
358
+ dC/dBk = A.T @ X.T @ D(Wk)
359
+
360
+ ### V projection LoRA weights
361
+ dC/dAv = X.T @ D(Wv) @ B.T
362
+ dC/dBv = A.T @ X.T @ D(Wv)
363
+ """
364
+ @staticmethod
365
+ @torch_amp_custom_fwd
366
+ def forward(ctx, X : torch.Tensor,
367
+ W, W_quant, A, B, S):
368
+ dtype = X.dtype
369
+ XW = matmul_lora(X, W, W_quant, A, B, S)
370
+ ctx.custom_saved_tensors = (W, W_quant, S,)
371
+ ctx.save_for_backward(A, B, X)
372
+ return XW
373
+ pass
374
+
375
+ @staticmethod
376
+ @torch_amp_custom_bwd
377
+ def backward(ctx, dY : torch.Tensor):
378
+ W, W_quant, S = ctx.custom_saved_tensors
379
+ A, B, X = ctx.saved_tensors
380
+
381
+ A, B = A.t(), B.t()
382
+
383
+ batch, seq_len, hd = X.shape
384
+ dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
385
+ X = X .reshape(-1, X .shape[-1]) # Must be reshape
386
+ dtype = X.dtype
387
+
388
+ ### Weight projection LoRA weights
389
+ # Weight projection
390
+ d_A = X.t() @ (dY @ B.t())
391
+ d_B = (A.t() @ X.t()) @ dY
392
+ d_A *= S
393
+ d_B *= S
394
+
395
+ # Get derivative for dX
396
+ W = fast_dequantize(W.t(), W_quant)
397
+ dX = dY @ W.t()
398
+ del W
399
+ dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
400
+
401
+ # W, W_quant, A, B, S
402
+ return dX.view(batch, seq_len, hd), \
403
+ None, None, d_A.t(), d_B.t(), None
404
+ pass
405
+ pass
406
+
407
+
408
+ def apply_lora_o(self, X):
409
+ OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
410
+ O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
411
+ return O
412
+ pass
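Editor's note: the autograd functions above avoid materializing the merged weight W + s * A @ B and instead fold the low-rank update into the matmul. A small PyTorch sketch of the identity this relies on (dimensions and the A/B layout here are illustrative, not the exact layout used by matmul_lora):

import torch

bsz, seq, hd, out_dim, rank, s = 2, 5, 16, 32, 4, 0.5
X = torch.randn(bsz, seq, hd)
W = torch.randn(hd, out_dim)
A = torch.randn(hd, rank)        # hypothetical LoRA "A" side
B = torch.randn(rank, out_dim)   # hypothetical LoRA "B" side

merged   = X @ (W + s * (A @ B))        # what a merged weight would give
unmerged = X @ W + s * ((X @ A) @ B)    # what the fused path computes
print(torch.allclose(merged, unmerged, atol = 1e-4))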
unsloth-main/unsloth-main/unsloth/kernels/flex_attention.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from functools import lru_cache
17
+ from transformers.models.llama.modeling_llama import logger
18
+
19
+ torch_compile_options = {
20
+ "epilogue_fusion" : True,
21
+ "max_autotune" : True,
22
+ "shape_padding" : True,
23
+ "trace.enabled" : False, # Output Triton kernel outputs!
24
+ "triton.cudagraphs" : False,
25
+ }
26
+
27
+ # Flex Attention supported from torch 2.5 onwards only
28
+ try:
29
+ from torch.nn.attention.flex_attention import (
30
+ flex_attention as _flex_attention,
31
+ create_block_mask as _create_block_mask,
32
+ )
33
+ _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
34
+ HAS_FLEX_ATTENTION = True
35
+ except:
36
+ HAS_FLEX_ATTENTION = False
37
+ pass
38
+
39
+
40
+ if not HAS_FLEX_ATTENTION:
41
+
42
+ # Logit softcapping
43
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
44
+ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
45
+ n_heads = self.num_heads
46
+ head_dim = self.head_dim
47
+ n_kv_heads = self.num_key_value_heads
48
+ n_groups = self.num_key_value_groups
49
+
50
+ # Grouped query attention
51
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
52
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
53
+ K = K.reshape(bsz, n_heads, q_len, head_dim)
54
+ V = V.reshape(bsz, n_heads, q_len, head_dim)
55
+
56
+ # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
57
+ # Gemma 9b should use 256 and not 224 (hidden_size / num_attention_heads). 27b uses the below
58
+ # We default to using the config file itself
59
+ # s = self.config.hidden_size // self.config.num_attention_heads
60
+ s = self.config.query_pre_attn_scalar
61
+ t = self.config.attn_logit_softcapping
62
+
63
+ Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
64
+ A = torch.matmul(Q, K.transpose(2, 3))
65
+ A = t * torch.tanh(A / t) # Logit softcapping
66
+ A += causal_mask[:q_len, :q_len]
67
+ # Much slower in torch compile!
68
+ # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
69
+ A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
70
+ A = torch.matmul(A, V)
71
+ A = A.transpose(1, 2).contiguous()
72
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
73
+ return A
74
+ pass
75
+
76
+ create_flex_attention_causal_mask = None
77
+ create_flex_attention_sliding_window_mask = None
78
+ else:
79
+ # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
80
+ # for more examples
81
+ # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
82
+ import functools, math
83
+
84
+ def generate_tanh_softcap(t):
85
+ def tanh_softcap(x, b, h, q_idx, kv_idx):
86
+ return t * torch.tanh(x / t)
87
+ return tanh_softcap
88
+ pass
89
+ def causal_masker(b, h, q_idx, kv_idx):
90
+ return q_idx >= kv_idx
91
+ pass
92
+
93
+ @functools.lru_cache
94
+ def sliding_window_masker(size = 4096):
95
+ def sliding_window(b, h, q_idx, kv_idx):
96
+ causal_mask = q_idx >= kv_idx
97
+ window_mask = q_idx - kv_idx <= size
98
+ return causal_mask & window_mask
99
+ return sliding_window
100
+ pass
101
+
102
+ @functools.lru_cache
103
+ def create_block_mask(mask, n = 128):
104
+ return _create_block_mask(
105
+ mask, 1, 1, n, n,
106
+ BLOCK_SIZE = 128,
107
+ _compile = True,
108
+ )
109
+ pass
110
+
111
+ def create_flex_attention_causal_mask(max_seq_length = 8192):
112
+ causal_mask = create_block_mask(causal_masker, max_seq_length)
113
+ return causal_mask
114
+ pass
115
+
116
+ def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
117
+ sliding_masker = sliding_window_masker(sliding_window)
118
+ causal_mask = create_block_mask(sliding_masker, max_seq_length)
119
+ return causal_mask
120
+ pass
121
+
122
+ @functools.lru_cache
123
+ def flex_attention(s, t):
124
+ scale = 1.0 / math.sqrt(s)
125
+ score_mod = generate_tanh_softcap(t)
126
+ return functools.partial(
127
+ _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
128
+ )
129
+ pass
130
+
131
+ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
132
+ n_heads = self.num_heads
133
+ head_dim = self.head_dim
134
+ s = self.config.query_pre_attn_scalar
135
+ t = self.config.attn_logit_softcapping
136
+ fx = flex_attention(s, t)
137
+ A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
138
+ A = A.transpose(1, 2).contiguous()
139
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
140
+ return A
141
+ pass
142
+ pass
143
+
144
+
145
+ torch_matmul = torch.matmul
146
+ torch_tanh = torch.tanh
147
+ torch_nn_functional_softmax = torch.nn.functional.softmax
148
+ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
149
+ n_heads = self.num_heads
150
+ head_dim = self.head_dim
151
+ n_kv_heads = self.num_key_value_heads
152
+ n_groups = self.num_key_value_groups
153
+
154
+ # Grouped query attention
155
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
156
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
157
+ K = K.reshape(bsz, n_heads, q_len, head_dim)
158
+ V = V.reshape(bsz, n_heads, q_len, head_dim)
159
+
160
+ # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
161
+ # Gemma 9b should use 256 and not 224 (hidden_size / num_attention_heads). 27b uses the below
162
+ # We default to using the config file itself
163
+ # s = self.config.hidden_size // self.config.num_attention_heads
164
+ s = self.config.query_pre_attn_scalar
165
+ t = self.config.attn_logit_softcapping
166
+
167
+ Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
168
+ A = torch_matmul(Q, K.transpose(2, 3))
169
+
170
+ # Logit softcapping
171
+ A /= t; torch_tanh(A, out = A); A *= t;
172
+ A += causal_mask[:q_len, :q_len]
173
+ # Much slower in torch compile!
174
+ # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
175
+ A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
176
+ A = torch_matmul(A, V)
177
+ A = A.transpose(1, 2).contiguous()
178
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
179
+ return A
180
+ pass
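Editor's note: a compact CPU sketch of the two ingredients the fallback paths above combine, grouped-query head expansion and tanh logit softcapping (sizes and the scale are illustrative; the kernels read config.query_pre_attn_scalar and attn_logit_softcapping instead):

import torch

bsz, n_kv_heads, n_groups, q_len, head_dim, t = 1, 2, 4, 6, 8, 50.0
n_heads = n_kv_heads * n_groups
Q = torch.randn(bsz, n_heads,    q_len, head_dim)
K = torch.randn(bsz, n_kv_heads, q_len, head_dim)
V = torch.randn(bsz, n_kv_heads, q_len, head_dim)

# Repeat each KV head across its query group (grouped-query attention)
K = K[:, :, None].expand(bsz, n_kv_heads, n_groups, q_len, head_dim).reshape(bsz, n_heads, q_len, head_dim)
V = V[:, :, None].expand(bsz, n_kv_heads, n_groups, q_len, head_dim).reshape(bsz, n_heads, q_len, head_dim)

scores = (Q * head_dim**-0.5) @ K.transpose(2, 3)   # simplified scale
scores = t * torch.tanh(scores / t)                 # logit softcapping
causal = torch.triu(torch.full((q_len, q_len), -float("inf")), diagonal = 1)
probs  = torch.softmax(scores + causal, dim = -1)
out    = probs @ V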
unsloth-main/unsloth-main/unsloth/kernels/geglu.py ADDED
@@ -0,0 +1,203 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings, triton_tanh
19
+
20
+
21
+ @triton.jit
22
+ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
23
+ block_idx = tl.program_id(0)
24
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
25
+ mask = offsets < n_elements
26
+
27
+ # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
28
+ # h = f * up
29
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
30
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
31
+
32
+ f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
33
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
34
+ h_row = f_row * g_row
35
+
36
+ # Store h
37
+ tl.store(h + offsets, h_row, mask = mask)
38
+ pass
39
+
40
+
41
+ def geglu_exact_forward_kernel(gate, up):
42
+ batch, seq_len, hd = gate.shape
43
+ n_elements = gate.numel()
44
+ out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
45
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
46
+ _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
47
+ return out
48
+ pass
49
+
50
+
51
+ @triton.jit
52
+ def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
53
+ """
54
+ f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
55
+ h = f * up
56
+
57
+ df/de (with help of Wolfram :)
58
+ df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
59
+
60
+ Reuse via
61
+ f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
62
+ """
63
+ block_idx = tl.program_id(0)
64
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
65
+ mask = offsets < n_elements
66
+
67
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
68
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
69
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
70
+
71
+ # Break e_row away for re-use
72
+ # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
73
+ f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
74
+ f_row = f_partial_row * e_row
75
+
76
+ f_row = f_row.to(DW_row.dtype)
77
+ # h = f * g
78
+ h_row = f_row * g_row
79
+ # df = DW * f
80
+ df_row = DW_row * f_row
81
+ # dg = DW * g
82
+ dg_row = DW_row * g_row
83
+
84
+ # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
85
+ t = 0.3989422804014327 # 1/sqrt(2*pi)
86
+ df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
87
+
88
+ de_row = dg_row.to(tl.float32) * df_de
89
+ de_row = de_row.to(DW_row.dtype)
90
+
91
+ # Store derivatives in buffers
92
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
93
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
94
+ tl.store(g + offsets, de_row, mask = mask) # de
95
+ pass
96
+
97
+
98
+ def geglu_exact_backward_kernel(DW, e, g):
99
+ batch_seq_len, hd = e.shape
100
+ n_elements = e.numel()
101
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
102
+ _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
103
+ return DW, e, g
104
+ pass
105
+
106
+
107
+ @triton.jit
108
+ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
109
+ block_idx = tl.program_id(0)
110
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
111
+ mask = offsets < n_elements
112
+
113
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
114
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
115
+ # h = f * up
116
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
117
+
118
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
119
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
120
+
121
+ f_row = 0.5 * e_row * (
122
+ triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
123
+ + 1.0
124
+ )
125
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
126
+ h_row = f_row * g_row
127
+
128
+ # Store h
129
+ tl.store(h + offsets, h_row, mask = mask)
130
+ pass
131
+
132
+
133
+ def geglu_approx_forward_kernel(gate, up):
134
+ batch, seq_len, hd = gate.shape
135
+ n_elements = gate.numel()
136
+ out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
137
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
138
+ _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
139
+ return out
140
+ pass
141
+
142
+
143
+ @triton.jit
144
+ def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
145
+ """
146
+ f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
147
+ h = f * up
148
+
149
+ df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
150
+ df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
151
+ 1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
152
+ ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
153
+
154
+ Notice sech^2(x) = 1 - tanh^2(x)
155
+ So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
156
+
157
+ See https://www.desmos.com/calculator/nqprfoni6x
158
+ """
159
+ block_idx = tl.program_id(0)
160
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
161
+ mask = offsets < n_elements
162
+
163
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
164
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
165
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
166
+
167
+ # See https://www.desmos.com/calculator/nqprfoni6x
168
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
169
+ a = s * e_row # a = sqrt(2 / pi) * x
170
+ b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
171
+ T = 1.0 + triton_tanh(a + b)
172
+ T2 = 0.5 * T
173
+ # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
174
+ Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
175
+ df_de = T2 + Q2 # 1/2 * (T + Q)
176
+
177
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
178
+ f_row = T2 * e_row
179
+ f_row = f_row.to(DW_row.dtype)
180
+ # h = f * g
181
+ h_row = f_row * g_row
182
+ # df = DW * f
183
+ df_row = DW_row * f_row
184
+ # dg = DW * g
185
+ dg_row = DW_row * g_row
186
+
187
+ de_row = dg_row.to(tl.float32) * df_de
188
+ de_row = de_row.to(DW_row.dtype)
189
+
190
+ # Store derivatives in buffers
191
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
192
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
193
+ tl.store(g + offsets, de_row, mask = mask) # de
194
+ pass
195
+
196
+
197
+ def geglu_approx_backward_kernel(DW, e, g):
198
+ batch_seq_len, hd = e.shape
199
+ n_elements = e.numel()
200
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
201
+ _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
202
+ return DW, e, g
203
+ pass
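Editor's note: the exact GEGLU forward above is the erf-based GELU gated by the up projection; a quick CPU sketch checking that formula against PyTorch's built-in gelu (tensor shapes are illustrative):

import torch
import torch.nn.functional as F

gate = torch.randn(2, 3, 8)
up   = torch.randn(2, 3, 8)

# f = 0.5 * e * (1 + erf(e / sqrt(2))), h = f * up  -- same formula as the kernel
f         = 0.5 * gate * (1.0 + torch.erf(gate * (2.0 ** -0.5)))
h_manual  = f * up
h_builtin = F.gelu(gate, approximate = "none") * up
print(torch.allclose(h_manual, h_builtin, atol = 1e-6))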
unsloth-main/unsloth-main/unsloth/kernels/layernorm.py ADDED
@@ -0,0 +1,231 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ # Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import triton
17
+ import triton.language as tl
18
+ import torch
19
+ from .utils import calculate_settings
20
+
21
+
22
+ @triton.jit
23
+ def layernorm_forward(
24
+ Y, Y_row_stride,
25
+ X, X_row_stride,
26
+ W,
27
+ b,
28
+ r,
29
+ mu,
30
+ n_cols, eps,
31
+ BLOCK_SIZE : tl.constexpr
32
+ ):
33
+ row_idx = tl.program_id(0)
34
+ col_offsets = tl.arange(0, BLOCK_SIZE)
35
+ mask = col_offsets < n_cols
36
+
37
+ Y += row_idx * Y_row_stride
38
+ X += row_idx * X_row_stride
39
+ r += row_idx
40
+ mu += row_idx
41
+
42
+ # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
43
+ # are in float32!
44
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
45
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
46
+ b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
47
+
48
+ mean_X = tl.sum(X_row, axis = 0) / n_cols
49
+ XX = X_row - mean_X
50
+ row_var = tl.sum(XX * XX, axis = 0) / n_cols
51
+ inv_var = tl.math.rsqrt(row_var + eps)
52
+ tl.store (r, inv_var)
53
+ tl.store (mu, mean_X)
54
+ output = (XX * inv_var) * W_row + b_row
55
+ tl.store(Y + col_offsets, output, mask = mask)
56
+ pass
57
+
58
+
59
+ @triton.jit
60
+ def layernorm_backward(
61
+ dY, dY_row_stride,
62
+ X, X_row_stride,
63
+ W,
64
+ b,
65
+ r,
66
+ mu,
67
+ n_cols, eps,
68
+ BLOCK_SIZE : tl.constexpr
69
+ ):
70
+ # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
71
+ row_idx = tl.program_id(0)
72
+ col_offsets = tl.arange(0, BLOCK_SIZE)
73
+ mask = col_offsets < n_cols
74
+
75
+ dY += row_idx * dY_row_stride
76
+ X += row_idx * X_row_stride
77
+ r += row_idx
78
+ mu += row_idx
79
+
80
+ # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
81
+ # are in float32!
82
+ dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
83
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
84
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
85
+ b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
86
+
87
+ inv_var = tl.load(r) .to(tl.float32)
88
+ mean = tl.load(mu).to(tl.float32)
89
+ normed = (X_row - mean) * inv_var
90
+ dY_W = dY_row * W_row
91
+ dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
92
+ dX_row = dX_row * inv_var
93
+ tl.store(dY + col_offsets, dX_row, mask = mask)
94
+ pass
95
+
96
+
97
+ class Fast_Layernorm(torch.autograd.Function):
98
+ @staticmethod
99
+ def forward(ctx, X, W, b, eps):
100
+ shape = X.shape
101
+ dim = shape[-1]
102
+ X = X.view(-1, dim)
103
+ n_rows, n_cols = X.shape
104
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
105
+
106
+ Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
107
+ r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
108
+ mu = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
109
+
110
+ layernorm_forward[(n_rows,)](
111
+ Y, Y.stride(0),
112
+ X, X.stride(0),
113
+ W,
114
+ b,
115
+ r,
116
+ mu,
117
+ n_cols, eps,
118
+ BLOCK_SIZE = BLOCK_SIZE,
119
+ num_warps = num_warps,
120
+ )
121
+ ctx.eps = eps
122
+ ctx.BLOCK_SIZE = BLOCK_SIZE
123
+ ctx.num_warps = num_warps
124
+ ctx.save_for_backward(X, W, b, r, mu)
125
+ return Y.view(*shape)
126
+ pass
127
+
128
+ @staticmethod
129
+ def backward(ctx, dY):
130
+ shape = dY.shape
131
+ dim = shape[-1]
132
+ dY = dY.view(-1, dim)
133
+ X, W, b, r, mu = ctx.saved_tensors
134
+ n_rows, n_cols = dY.shape
135
+
136
+ layernorm_backward[(n_rows,)](
137
+ dY, dY.stride(0),
138
+ X, X .stride(0),
139
+ W,
140
+ b,
141
+ r,
142
+ mu,
143
+ n_cols, ctx.eps,
144
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
145
+ num_warps = ctx.num_warps,
146
+ )
147
+ dX = dY.view(*shape)
148
+ return dX, None, None, None, None
149
+ pass
150
+ pass
151
+
152
+
153
+ def fast_layernorm(layernorm, X):
154
+ assert(layernorm.elementwise_affine is True)
155
+ W = layernorm.weight
156
+ bias = layernorm.bias
157
+ eps = layernorm.variance_epsilon if \
158
+ hasattr(layernorm, "variance_epsilon") \
159
+ else layernorm.eps
160
+ out = Fast_Layernorm.apply(X, W, bias, eps)
161
+ return out
162
+ pass
163
+
164
+
165
+ from torch.nn import LayerNorm
166
+ class Unsloth_LayerNorm(LayerNorm):
167
+ def forward(self, X):
168
+ return fast_layernorm(self, X)
169
+ pass
170
+ pass
171
+
172
+
173
+ def patch_layernorm():
174
+ import torch.nn
175
+ torch.nn.LayerNorm = Unsloth_LayerNorm
176
+ return
177
+ pass
178
+
179
+
180
+ def unpatch_layernorm():
181
+ import torch.nn
182
+ torch.nn.LayerNorm = LayerNorm
183
+ return
184
+ pass
185
+
186
+
187
+ def test_layernorm(
188
+ dim = 1024, eps = 1e-5, dtype = torch.float16,
189
+ bsz = 21, random_state = 3407, seqlen = 3341,
190
+ ):
191
+ from torch.nn import LayerNorm
192
+ layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
193
+ torch.cuda.manual_seed(random_state)
194
+ torch.manual_seed(random_state)
195
+ torch.nn.init.uniform_(layernorm.weight)
196
+ torch.nn.init.uniform_(layernorm.bias)
197
+ X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
198
+ XX = X.clone()
199
+ X .requires_grad_(True)
200
+ XX.requires_grad_(True)
201
+ Y = layernorm(X)
202
+ YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
203
+ Y.backward(YY)
204
+ correct_grad = X.grad.clone()
205
+ # from unsloth.kernels import fast_layernorm
206
+ Y = fast_layernorm(layernorm, XX)
207
+ Y.backward(YY)
208
+ assert(torch.dist(correct_grad, XX.grad).item() <= 0.1)
209
+ pass
210
+
211
+
212
+ def testing_suite_layernorm():
213
+ for dim in [512, 1024, 2048]:
214
+ for dtype in [torch.float16, torch.bfloat16]:
215
+ with torch.autocast(device_type = "cuda", dtype = dtype):
216
+ for seqlen in [3341, 2048, 349]:
217
+ for random_state in [3407, 42]:
218
+ test_layernorm(
219
+ dim = dim,
220
+ eps = 1e-5,
221
+ dtype = dtype,
222
+ bsz = 21,
223
+ random_state = random_state,
224
+ seqlen = seqlen,
225
+ )
226
+ pass
227
+ pass
228
+ pass
229
+ pass
230
+ pass
231
+ pass
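A minimal usage sketch (an assumed workflow, not taken from the file): patch_layernorm swaps torch.nn.LayerNorm for Unsloth_LayerNorm, so it has to run before the layers are constructed, and unpatch_layernorm restores the stock class. This assumes a single CUDA device and half-precision inputs.

import torch

patch_layernorm()                    # torch.nn.LayerNorm now builds Unsloth_LayerNorm
ln = torch.nn.LayerNorm(1024).to("cuda", torch.float16)
y  = ln(torch.randn(2, 8, 1024, device = "cuda", dtype = torch.float16))
unpatch_layernorm()                  # restore the original implementation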
unsloth-main/unsloth-main/unsloth/kernels/rms_layernorm.py ADDED
@@ -0,0 +1,283 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings
19
+
20
+
21
+ @triton.jit
22
+ def _rms_layernorm_forward(
23
+ Y, Y_row_stride,
24
+ X, X_row_stride,
25
+ W, W_row_stride,
26
+ r, r_row_stride,
27
+ n_cols, eps,
28
+ BLOCK_SIZE : tl.constexpr
29
+ ):
30
+ """
31
+ Fast RMS Layernorm kernel
32
+ Inspiration from a Triton tutorial:
33
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
34
+ """
35
+ row_idx = tl.program_id(0)
36
+ col_offsets = tl.arange(0, BLOCK_SIZE)
37
+ mask = col_offsets < n_cols
38
+
39
+ Y += row_idx * Y_row_stride
40
+ X += row_idx * X_row_stride
41
+ r += row_idx * r_row_stride
42
+
43
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
44
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
45
+
46
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
47
+ inv_var = tl.math.rsqrt(row_var + eps)
48
+ tl.store(r, inv_var)
49
+ normed = X_row * inv_var
50
+ normed = normed.to(W_row.dtype) # Exact copy from HF
51
+ output = normed * W_row
52
+ tl.store(Y + col_offsets, output, mask = mask)
53
+ pass
54
+
55
+
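In plain PyTorch, the forward pass above computes the following (a reference sketch only; the function name is an assumption):

import torch

def rms_layernorm_reference(X, W, eps):
    # y = x * rsqrt(mean(x^2) + eps) * w, with the normalisation done in float32
    # and the result cast back to W's dtype before multiplying, as in the HF code.
    X32     = X.float()
    inv_var = torch.rsqrt(X32.square().mean(-1, keepdim = True) + eps)
    return (X32 * inv_var).to(W.dtype) * W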
56
+ @triton.heuristics({"GEMMA": lambda args: args["GEMMA"],})
57
+ @triton.jit
58
+ def _rms_layernorm_backward(
59
+ dY, dY_row_stride,
60
+ X, X_row_stride,
61
+ W, W_row_stride,
62
+ r, r_row_stride,
63
+ dW, dW_row_stride,
64
+ n_cols, eps,
65
+ GEMMA : tl.constexpr,
66
+ BLOCK_SIZE : tl.constexpr,
67
+ ):
68
+ """
69
+ Fast RMS Layernorm kernel for the backward pass
70
+ Inspiration from a Triton tutorial:
71
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
72
+ """
73
+ row_idx = tl.program_id(0)
74
+ col_offsets = tl.arange(0, BLOCK_SIZE)
75
+ mask = col_offsets < n_cols
76
+
77
+ dY += row_idx * dY_row_stride
78
+ X += row_idx * X_row_stride
79
+ r += row_idx * r_row_stride
80
+
81
+ dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
82
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
83
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
84
+
85
+ # Get saved row variance
86
+ inv_var = tl.load(r).to(tl.float32)
87
+ normed = X_row * inv_var
88
+
89
+ if GEMMA: dY_W = dY_row * (W_row + 1.0)
90
+ else: dY_W = dY_row * W_row
91
+
92
+ rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
93
+ output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
94
+ tl.store(dY + col_offsets, output, mask = mask)
95
+ pass
96
+
97
+
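The gradient stored back into dY can likewise be restated row-wise in PyTorch (a sketch for checking the formula; the helper name is assumed). For Gemma the effective weight is W + 1, matching the GEMMA branch above.

import torch

def rms_layernorm_backward_reference(dY, X, W, inv_var, gemma = False):
    # dX = inv_var/n * (n * dY*W' - normed * sum(dY*W' * normed)), with W' = W + 1 for Gemma
    W_eff  = W + 1.0 if gemma else W
    normed = X * inv_var.unsqueeze(-1)
    dY_W   = dY * W_eff
    rowsum = (dY_W * normed).sum(-1, keepdim = True)
    return inv_var.unsqueeze(-1) * (dY_W - normed * rowsum / X.shape[-1])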
98
+ @triton.jit
99
+ def _gemma_rms_layernorm_forward(
100
+ Y, Y_row_stride,
101
+ X, X_row_stride,
102
+ W, W_row_stride,
103
+ r, r_row_stride,
104
+ n_cols, eps,
105
+ BLOCK_SIZE : tl.constexpr,
106
+ ):
107
+ # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
108
+ # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
109
+ # exactly. Essentially all in float32!
110
+ row_idx = tl.program_id(0)
111
+ col_offsets = tl.arange(0, BLOCK_SIZE)
112
+ mask = col_offsets < n_cols
113
+
114
+ Y += row_idx * Y_row_stride
115
+ X += row_idx * X_row_stride
116
+ r += row_idx * r_row_stride
117
+
118
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
119
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
120
+
121
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
122
+ inv_var = tl.math.rsqrt(row_var + eps)
123
+ tl.store(r, inv_var)
124
+ normed = X_row * inv_var
125
+ output = normed * (W_row + 1.0)
126
+
127
+ tl.store(Y + col_offsets, output, mask = mask)
128
+ pass
129
+
130
+
131
+ class Fast_RMS_Layernorm(torch.autograd.Function):
132
+ @staticmethod
133
+ def forward(ctx, X, W, eps, gemma = False):
134
+ shape = X.shape
135
+ dim = shape[-1]
136
+ X = X.view(-1, dim)
137
+ n_rows, n_cols = X.shape
138
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
139
+
140
+ Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
141
+ r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
142
+
143
+ fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
144
+ fx[(n_rows,)](
145
+ Y, Y.stride(0),
146
+ X, X.stride(0),
147
+ W, W.stride(0),
148
+ r, r.stride(0),
149
+ n_cols, eps,
150
+ BLOCK_SIZE = BLOCK_SIZE,
151
+ num_warps = num_warps,
152
+ )
153
+ ctx.eps = eps
154
+ ctx.BLOCK_SIZE = BLOCK_SIZE
155
+ ctx.num_warps = num_warps
156
+ ctx.GEMMA = gemma
157
+ ctx.save_for_backward(X, W, r)
158
+ return Y.view(*shape)
159
+ pass
160
+
161
+ @staticmethod
162
+ def backward(ctx, dY):
163
+ shape = dY.shape
164
+ dim = shape[-1]
165
+ dY = dY.view(-1, dim)
166
+ X, W, r = ctx.saved_tensors
167
+ n_rows, n_cols = dY.shape
168
+ dW = X
169
+
170
+ _rms_layernorm_backward[(n_rows,)](
171
+ dY, dY.stride(0),
172
+ X, X .stride(0),
173
+ W, W .stride(0),
174
+ r, r .stride(0),
175
+ dW, dW.stride(0),
176
+ n_cols, ctx.eps,
177
+ GEMMA = ctx.GEMMA,
178
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
179
+ num_warps = ctx.num_warps,
180
+ )
181
+ dX = dY.view(*shape)
182
+ return dX, None, None, None
183
+ pass
184
+ pass
185
+
186
+
187
+ def fast_rms_layernorm(layernorm, X, gemma = False):
188
+ W = layernorm.weight
189
+ eps = layernorm.variance_epsilon if \
190
+ hasattr(layernorm, "variance_epsilon") \
191
+ else layernorm.eps
192
+ out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
193
+ return out
194
+ pass
195
+
196
+
197
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
198
+ class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
199
+ def forward(self, X):
200
+ return fast_rms_layernorm(self, X, gemma = False)
201
+ pass
202
+ pass
203
+
204
+ try:
205
+ from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
206
+ class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
207
+ def forward(self, X):
208
+ return fast_rms_layernorm(self, X, gemma = False)
209
+ pass
210
+ pass
211
+ except:
212
+ pass
213
+ pass
214
+
215
+ def patch_rms_layernorm():
216
+ import transformers.models.llama.modeling_llama
217
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
218
+ try:
219
+ import transformers.models.mllama.modeling_mllama
220
+ transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
221
+ except:
222
+ pass
223
+ return
224
+ pass
225
+
226
+
227
+ def unpatch_rms_layernorm():
228
+ import transformers.models.llama.modeling_llama
229
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
230
+ try:
231
+ import transformers.models.mllama.modeling_mllama
232
+ transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
233
+ except:
234
+ pass
235
+ return
236
+ return
237
+ pass
238
+
239
+
240
+ def test_rms_layernorm(
241
+ dim = 1024, eps = 1e-5, dtype = torch.float16,
242
+ bsz = 21, random_state = 3407, seqlen = 3341,
243
+ ):
244
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
245
+ layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
246
+ torch.cuda.manual_seed(random_state)
247
+ torch.manual_seed(random_state)
248
+ torch.nn.init.uniform_(layernorm.weight)
249
+ X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
250
+ XX = X.clone()
251
+ X .requires_grad_(True)
252
+ XX.requires_grad_(True)
253
+ Y = layernorm(X)
254
+ YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
255
+ Y.backward(YY)
256
+ correct_grad = X.grad.clone()
257
+ # from unsloth.kernels import fast_rms_layernorm
258
+ Y = fast_rms_layernorm(layernorm, XX)
259
+ Y.backward(YY)
260
+ assert(torch.amax(correct_grad - XX.grad).item() <= 0.05)
261
+ pass
262
+
263
+
264
+ def testing_suite_layernorm():
265
+ for dim in [512, 1024, 2048]:
266
+ for dtype in [torch.float16, torch.bfloat16]:
267
+ with torch.autocast(device_type = "cuda", dtype = dtype):
268
+ for seqlen in [3341, 2048, 349]:
269
+ for random_state in [3407, 42]:
270
+ test_rms_layernorm(
271
+ dim = dim,
272
+ eps = 1e-5,
273
+ dtype = dtype,
274
+ bsz = 21,
275
+ random_state = random_state,
276
+ seqlen = seqlen,
277
+ )
278
+ pass
279
+ pass
280
+ pass
281
+ pass
282
+ pass
283
+ pass
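A minimal usage sketch (assumed workflow, not part of the file): the patch swaps the transformers RMSNorm classes, so it must run before the model is constructed.

patch_rms_layernorm()
# ... build or load a Llama-family model here; its RMSNorm layers now use the Triton kernels ...
unpatch_rms_layernorm()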
unsloth-main/unsloth-main/unsloth/kernels/rope_embedding.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings
19
+ ROPE_GROUP_SIZE = 4
20
+
21
+ @triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
22
+ @triton.jit
23
+ def _rope_embedding(
24
+ Q, Q_row_stride,
25
+ cos, cos_row_stride,
26
+ sin, sin_row_stride,
27
+ seqlen,
28
+ head_dim : tl.constexpr,
29
+ n_heads : tl.constexpr,
30
+ BACKWARD_PASS : tl.constexpr,
31
+ BLOCK_SIZE : tl.constexpr,
32
+ ):
33
+ """
34
+ Calculates the RoPE Embedding quickly
35
+ RoPE is Q * cos + rotate_half(Q) * sin
36
+ See our blog post for more info
37
+ """
38
+ ROPE_GROUP_SIZE = 4
39
+ row_position = tl.program_id(0)
40
+ group_head_position = tl.program_id(1)
41
+ col_offsets = tl.arange(0, BLOCK_SIZE)
42
+ half_head_dim = head_dim // 2
43
+ mask = col_offsets < half_head_dim
44
+
45
+ sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
46
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
47
+ cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
48
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
49
+
50
+ if BACKWARD_PASS:
51
+ # See our blog post for more info.
52
+ sin1 = -sin1
53
+ pass
54
+
55
+ # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
56
+ head_start = group_head_position * ROPE_GROUP_SIZE
57
+ head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
58
+
59
+ # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
60
+ for k in range(head_start, head_end):
61
+ offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
62
+ offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
63
+
64
+ # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
65
+ Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
66
+ Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
67
+
68
+ tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
69
+ tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
70
+ pass
71
+ pass
72
+
73
+
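Per head, the in-place rotation the kernel performs is the standard rotate-half RoPE. A PyTorch restatement for reference (the helper name and the half-width cos/sin layout are taken from the kernel's indexing, not defined elsewhere in the file):

import torch

def rope_reference(q, cos_half, sin_half):
    # q: (..., head_dim); cos_half / sin_half: (..., head_dim//2), i.e. the first
    # half_head_dim entries per position that the kernel loads as cos1 / sin1.
    half   = q.shape[-1] // 2
    q1, q2 = q[..., :half], q[..., half:]
    # Matches the two tl.store calls: out1 = q1*cos - q2*sin, out2 = q2*cos + q1*sin
    return torch.cat((q1 * cos_half - q2 * sin_half, q2 * cos_half + q1 * sin_half), dim = -1)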
74
+ class Fast_RoPE_Embedding(torch.autograd.Function):
75
+ @staticmethod
76
+ def forward(ctx, Q, cos, sin):
77
+ cos, sin = cos.squeeze(), sin.squeeze()
78
+ batch, seq_len, n_heads, head_dim = Q.shape
79
+ Q = Q.view(batch*seq_len, n_heads*head_dim)
80
+ n_rows, n_cols = Q.shape
81
+ assert(seq_len <= cos.shape[0])
82
+
83
+ # [TODO] Changing blocksize to head_dim//2 seems to have
84
+ # some concurrency / un-deterministic issues.
85
+ BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
86
+
87
+ # group_size = 4 # 4 or 8, too large group_size can hurt performance.
88
+ div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
89
+ n_groups = div + (mod != 0)
90
+
91
+ _rope_embedding[(n_rows, n_groups, )](
92
+ Q, Q.stride(0),
93
+ cos, cos.stride(0),
94
+ sin, sin.stride(0),
95
+ seq_len,
96
+ head_dim, n_heads,
97
+ BACKWARD_PASS = False,
98
+ BLOCK_SIZE = BLOCK_SIZE,
99
+ num_warps = num_warps,
100
+ )
101
+ ctx.BLOCK_SIZE = BLOCK_SIZE
102
+ ctx.num_warps = num_warps
103
+ ctx.n_groups = n_groups
104
+ ctx.cos = cos
105
+ ctx.sin = sin
106
+ return Q.view(batch, seq_len, n_heads, head_dim)
107
+ pass
108
+
109
+ @staticmethod
110
+ def backward(ctx, dY):
111
+ batch, seq_len, n_heads, head_dim = dY.shape
112
+ dY = dY.reshape(batch*seq_len, n_heads*head_dim)
113
+ # Must be reshape not view
114
+ n_rows, n_cols = dY.shape
115
+
116
+ cos = ctx.cos
117
+ sin = ctx.sin
118
+
119
+ _rope_embedding[(n_rows, ctx.n_groups, )](
120
+ dY, dY .stride(0),
121
+ cos, cos.stride(0),
122
+ sin, sin.stride(0),
123
+ seq_len, head_dim, n_heads,
124
+ BACKWARD_PASS = True,
125
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
126
+ num_warps = ctx.num_warps,
127
+ )
128
+ dY = dY.view(batch, seq_len, n_heads, head_dim)
129
+ return dY, None, None,
130
+ pass
131
+ pass
132
+
133
+
134
+ def fast_rope_embedding(Q, K, cos, sin):
135
+ Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
136
+ K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
137
+ return Q, K
138
+ pass
139
+
140
+
141
+ class Slow_RoPE_Embedding(torch.autograd.Function):
142
+ @staticmethod
143
+ def forward(ctx, Q, cos, sin, position_ids):
144
+ if position_ids is not None:
145
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
146
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
147
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
148
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
149
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
150
+
151
+ # Q * cos + rotate_half(Q) * sin
152
+ half = Q.shape[-1]//2
153
+ RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
154
+ Q *= cos
155
+ Q.addcmul_(RH_Q, sin)
156
+ # RH_Q *= sin
157
+ # Q += RH_Q
158
+ ctx.save_for_backward(cos, sin)
159
+ return Q
160
+ pass
161
+
162
+ @staticmethod
163
+ def backward(ctx, dY):
164
+ cos, sin = ctx.saved_tensors
165
+ # Q * cos + rotate_half.T(Q) * sin
166
+ half = dY.shape[-1]//2
167
+ RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
168
+ dY *= cos
169
+ dY.addcmul_(RH_dY, sin)
170
+ # RH_dY *= sin
171
+ # dY += RH_dY
172
+ return dY, None, None, None
173
+ pass
174
+ pass
175
+
176
+
177
+ def inplace_rope_embedding(Q, K, cos, sin, position_ids):
178
+ Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
179
+ K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
180
+ return Q, K
181
+ pass
unsloth-main/unsloth-main/unsloth/kernels/swiglu.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings
19
+
20
+
21
+ @triton.jit
22
+ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
23
+ block_idx = tl.program_id(0)
24
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
25
+ mask = offsets < n_elements
26
+
27
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
28
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
29
+
30
+ # f = e * sigmoid(e)
31
+ f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
32
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
33
+ # h = f * g
34
+ h_row = f_row * g_row
35
+
36
+ # Store h
37
+ tl.store(h + offsets, h_row, mask = mask)
38
+ pass
39
+
40
+
41
+ def swiglu_fg_kernel(e, g):
42
+ batch, seq_len, hd = e.shape
43
+ n_elements = e.numel()
44
+ h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
45
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
46
+ _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
47
+ return h
48
+ pass
49
+
50
+
51
+ @triton.jit
52
+ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
53
+ """
54
+ e = e.float()
55
+ se = 1.0 / (1.0 + torch.exp(-e))
56
+ f = (se * e).to(dtype)
57
+ h = f * g
58
+ df = DW * f
59
+ dg = DW * g
60
+ de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
61
+ """
62
+ block_idx = tl.program_id(0)
63
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
64
+ mask = offsets < n_elements
65
+
66
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
67
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
68
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
69
+
70
+ # e = e.float()
71
+ # se = 1.0 / (1.0 + torch.exp(-e))
72
+ se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
73
+ # f = (se * e).to(dtype)
74
+ f_row = se_row * e_row
75
+ f_row = f_row.to(DW_row.dtype)
76
+ # h = f * g
77
+ h_row = f_row * g_row
78
+ # df = DW * f
79
+ df_row = DW_row * f_row
80
+ # dg = DW * g
81
+ dg_row = DW_row * g_row
82
+ # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
83
+ de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
84
+ de_row = de_row.to(DW_row.dtype)
85
+
86
+ # Store derivatives in buffers
87
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
88
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
89
+ tl.store(g + offsets, de_row, mask = mask) # de
90
+ pass
91
+
92
+
93
+ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
94
+ batch_seq_len, hd = e.shape
95
+ n_elements = e.numel()
96
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
97
+ _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
98
+ return DW, e, g
99
+ pass
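The docstring math inside _DWf_DW_dfg_kernel can be checked against this dense PyTorch restatement (a sketch; the function name is assumed, and the three results correspond to what the kernel writes back into the DW, e and g buffers respectively):

import torch

def swiglu_backward_reference(DW, e, g):
    dtype = DW.dtype
    e32   = e.float()
    se    = torch.sigmoid(e32)
    f     = (se * e32).to(dtype)
    h     = f * g                                                          # written into DW
    df    = DW * f                                                         # written into e
    de    = ((DW * g).float() * se * (1.0 + e32 * (1.0 - se))).to(dtype)   # written into g
    return h, df, de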
unsloth-main/unsloth-main/unsloth/kernels/utils.py ADDED
@@ -0,0 +1,416 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ MAX_FUSED_SIZE = 65536
17
+ next_power_of_2 = triton.next_power_of_2
18
+
19
+ # torch.cuda.amp.custom_fwd is deprecated >= 2.4
20
+ import torch
21
+ from packaging.version import Version
22
+ if Version(torch.__version__) < Version("2.4.0"):
23
+ torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
24
+ torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
25
+ else:
26
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
27
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
28
+ pass
29
+
30
+
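These shims are meant to be used as decorators on custom autograd functions so the same code path works on torch < 2.4 and >= 2.4. A hypothetical sketch (the _ScaleByTwo class is illustrative only; torch is already imported above):

class _ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2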
31
+ # tl.math.tanh now is libdevice.tanh
32
+ from packaging.version import Version
33
+ import triton
34
+ if Version(triton.__version__) >= Version("3.0.0"):
35
+ from triton.language.extra import libdevice
36
+ triton_tanh = libdevice.tanh
37
+ else:
38
+ import triton.language as tl
39
+ triton_tanh = tl.math.tanh
40
+ pass
41
+
42
+
43
+ def calculate_settings(n):
44
+ BLOCK_SIZE = next_power_of_2(n)
45
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
46
+ raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
47
+ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
48
+ num_warps = 4
49
+ if BLOCK_SIZE >= 32768: num_warps = 32
50
+ elif BLOCK_SIZE >= 8192: num_warps = 16
51
+ elif BLOCK_SIZE >= 2048: num_warps = 8
52
+ return BLOCK_SIZE, num_warps
53
+ pass
54
+
55
+
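For a feel of the thresholds above, a few values that follow directly from the code:

# calculate_settings(1024)  -> BLOCK_SIZE = 1024, num_warps = 4
# calculate_settings(3341)  -> BLOCK_SIZE = 4096, num_warps = 8
# calculate_settings(8192)  -> BLOCK_SIZE = 8192, num_warps = 16
# calculate_settings(70000) raises RuntimeError (exceeds MAX_FUSED_SIZE = 65536)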
56
+ import bitsandbytes as bnb
57
+ # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
58
+ HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
59
+ global CUDA_STREAM
60
+ CUDA_STREAM = None
61
+ get_ptr = bnb.functional.get_ptr
62
+ import ctypes
63
+ cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
64
+ cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
65
+ cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
66
+ cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
67
+ cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
68
+
69
+
70
+ def QUANT_STATE(W):
71
+ return getattr(W, "quant_state", None)
72
+ pass
73
+
74
+
75
+ def get_lora_parameters(proj):
76
+ # For DPO or disabled adapters
77
+ base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
78
+ W = base_layer.weight
79
+
80
+ if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
81
+ return W, QUANT_STATE(W), None, None, None
82
+ pass
83
+
84
+ active_adapter = proj.active_adapters[0] if \
85
+ hasattr(proj, "active_adapters") else proj.active_adapter
86
+ A = proj.lora_A [active_adapter].weight
87
+ B = proj.lora_B [active_adapter].weight
88
+ s = proj.scaling[active_adapter]
89
+ return W, QUANT_STATE(W), A, B, s
90
+ pass
91
+
92
+
93
+ def get_lora_parameters_bias(proj):
94
+ # For DPO or disabled adapters
95
+ base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
96
+ W = base_layer.weight
97
+ bias = base_layer.bias
98
+
99
+ if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
100
+ return W, QUANT_STATE(W), None, None, None, bias
101
+ pass
102
+
103
+ active_adapter = proj.active_adapters[0] if \
104
+ hasattr(proj, "active_adapters") else proj.active_adapter
105
+ A = proj.lora_A [active_adapter].weight
106
+ B = proj.lora_B [active_adapter].weight
107
+ s = proj.scaling[active_adapter]
108
+ return W, QUANT_STATE(W), A, B, s, bias
109
+ pass
110
+
111
+
112
+ if HAS_CUDA_STREAM:
113
+ def fast_dequantize(W, quant_state = None, out = None):
114
+ if quant_state is None: return W
115
+ if type(quant_state) is not list:
116
+ # New quant_state as a class
117
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
118
+ absmax = quant_state.absmax
119
+ shape = quant_state.shape
120
+ dtype = quant_state.dtype
121
+ blocksize = quant_state.blocksize
122
+ offset = quant_state.offset
123
+ state2 = quant_state.state2
124
+ absmax2 = state2.absmax
125
+ code2 = state2.code
126
+ blocksize2 = state2.blocksize
127
+ else:
128
+ # Old quant_state as a list of lists
129
+ absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
130
+ offset, state2 = compressed_stats
131
+ absmax2, code2, blocksize2, _, _, _, _ = state2
132
+ pass
133
+ global CUDA_STREAM
134
+ if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
135
+
136
+ # Create weight matrix
137
+ if out is None:
138
+ out = torch.empty(shape, dtype = dtype, device = "cuda:0")
139
+ else:
140
+ assert(out.shape == shape)
141
+ assert(out.dtype == dtype)
142
+
143
+ # NF4 dequantization of statistics
144
+ n_elements_absmax = absmax.numel()
145
+ out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
146
+
147
+ # Do dequantization
148
+ ptr_out_absmax = get_ptr(out_absmax)
149
+ cdequantize_blockwise_fp32(
150
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
151
+ ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM,
152
+ )
153
+ out_absmax += offset
154
+
155
+ fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
156
+ cdequantize_blockwise_bf16_nf4
157
+ fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
158
+ ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,)
159
+
160
+ # Careful returning transposed data
161
+ is_transposed = (True if W.shape[0] == 1 else False)
162
+ return out.t() if is_transposed else out
163
+ pass
164
+ else:
165
+ def fast_dequantize(W, quant_state = None, out = None):
166
+ if quant_state is None: return W
167
+ if type(quant_state) is not list:
168
+ # New quant_state as a class
169
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
170
+ absmax = quant_state.absmax
171
+ shape = quant_state.shape
172
+ dtype = quant_state.dtype
173
+ blocksize = quant_state.blocksize
174
+ offset = quant_state.offset
175
+ state2 = quant_state.state2
176
+ absmax2 = state2.absmax
177
+ code2 = state2.code
178
+ blocksize2 = state2.blocksize
179
+ else:
180
+ # Old quant_state as a list of lists
181
+ absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
182
+ offset, state2 = compressed_stats
183
+ absmax2, code2, blocksize2, _, _, _, _ = state2
184
+ pass
185
+
186
+ # Create weight matrix
187
+ if out is None:
188
+ out = torch.empty(shape, dtype = dtype, device = "cuda:0")
189
+ else:
190
+ assert(out.shape == shape)
191
+ assert(out.dtype == dtype)
192
+
193
+ # NF4 dequantization of statistics
194
+ n_elements_absmax = absmax.numel()
195
+ out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
196
+
197
+ # Do dequantization
198
+ ptr_out_absmax = get_ptr(out_absmax)
199
+ cdequantize_blockwise_fp32(
200
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
201
+ ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax),
202
+ )
203
+ out_absmax += offset
204
+
205
+ fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
206
+ cdequantize_blockwise_bf16_nf4
207
+ fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
208
+ ctypes.c_int(blocksize), ctypes.c_int(out.numel()),)
209
+
210
+ # Careful returning transposed data
211
+ is_transposed = (True if W.shape[0] == 1 else False)
212
+ return out.t() if is_transposed else out
213
+ pass
214
+ pass
215
+
216
+
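For reference only: ignoring the pre-allocated buffers and the explicit CUDA stream, the same NF4 result can be obtained through bitsandbytes' high-level functional API. This assumes W is the packed 4-bit tensor and quant_state its bitsandbytes QuantState; it is a sketch, not how this file does it.

import bitsandbytes as bnb
W_dequant = bnb.functional.dequantize_4bit(W, quant_state)   # float16 / bfloat16 weight matrix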
217
+ if HAS_CUDA_STREAM:
218
+ def fast_gemv(X, W, quant_state, out = None):
219
+ if quant_state is None: return torch.matmul(X, W, out = out)
220
+ # For fast X @ W where seq_len == 1
221
+ # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
222
+ _, q_len, hd = X.shape
223
+ # assert(q_len == 1)
224
+
225
+ if type(quant_state) is not list:
226
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
227
+ absmax = quant_state.absmax
228
+ shape = quant_state.shape
229
+ dtype = quant_state.dtype
230
+ blocksize = quant_state.blocksize
231
+ stats = quant_state.code
232
+ offset = quant_state.offset
233
+ state2 = quant_state.state2
234
+ absmax2 = state2.absmax
235
+ code2 = state2.code
236
+ blocksize2 = state2.blocksize
237
+ else:
238
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
239
+ offset, state2 = compressed_stats
240
+ absmax2, code2, blocksize2, _, _, _, _ = state2
241
+ pass
242
+ global CUDA_STREAM
243
+ if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
244
+
245
+ # assert(dtype == X.dtype)
246
+ bout = shape[0]
247
+
248
+ if out is None:
249
+ out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
250
+ # else:
251
+ # assert(out.shape == (1, 1, bout,))
252
+ # pass
253
+
254
+ n = 1
255
+ m = shape[0]
256
+ k = shape[1]
257
+ lda = shape[0]
258
+ ldc = shape[0]
259
+ ldb = (hd+1)//2
260
+ m = ctypes.c_int32(m)
261
+ n = ctypes.c_int32(n)
262
+ k = ctypes.c_int32(k)
263
+ lda = ctypes.c_int32(lda)
264
+ ldb = ctypes.c_int32(ldb)
265
+ ldc = ctypes.c_int32(ldc)
266
+
267
+ df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
268
+ cdequantize_blockwise_fp32(
269
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
270
+ ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM,
271
+ )
272
+ df += offset
273
+ absmax = df
274
+
275
+ fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
276
+ cgemm_4bit_inference_naive_bf16
277
+
278
+ blocksize = ctypes.c_int32(blocksize)
279
+ fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
280
+ lda, ldb, ldc, blocksize, CUDA_STREAM,)
281
+
282
+ return out
283
+ pass
284
+ else:
285
+ def fast_gemv(X, W, quant_state, out = None):
286
+ if quant_state is None: return torch.matmul(X, W, out = out)
287
+ # For fast X @ W where seq_len == 1
288
+ # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
289
+ _, q_len, hd = X.shape
290
+ # assert(q_len == 1)
291
+
292
+ if type(quant_state) is not list:
293
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
294
+ absmax = quant_state.absmax
295
+ shape = quant_state.shape
296
+ dtype = quant_state.dtype
297
+ blocksize = quant_state.blocksize
298
+ stats = quant_state.code
299
+ offset = quant_state.offset
300
+ state2 = quant_state.state2
301
+ absmax2 = state2.absmax
302
+ code2 = state2.code
303
+ blocksize2 = state2.blocksize
304
+ else:
305
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
306
+ offset, state2 = compressed_stats
307
+ absmax2, code2, blocksize2, _, _, _, _ = state2
308
+ pass
309
+ # assert(dtype == X.dtype)
310
+ bout = shape[0]
311
+
312
+ if out is None:
313
+ out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
314
+ # else:
315
+ # assert(out.shape == (1, 1, bout,))
316
+ # pass
317
+
318
+ n = 1
319
+ m = shape[0]
320
+ k = shape[1]
321
+ lda = shape[0]
322
+ ldc = shape[0]
323
+ ldb = (hd+1)//2
324
+ m = ctypes.c_int32(m)
325
+ n = ctypes.c_int32(n)
326
+ k = ctypes.c_int32(k)
327
+ lda = ctypes.c_int32(lda)
328
+ ldb = ctypes.c_int32(ldb)
329
+ ldc = ctypes.c_int32(ldc)
330
+
331
+ df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
332
+ cdequantize_blockwise_fp32(
333
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
334
+ ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
335
+ )
336
+ df += offset
337
+ absmax = df
338
+
339
+ fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
340
+ cgemm_4bit_inference_naive_bf16
341
+
342
+ blocksize = ctypes.c_int32(blocksize)
343
+ fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
344
+ lda, ldb, ldc, blocksize,)
345
+
346
+ return out
347
+ pass
348
+ pass
349
+
350
+
351
+ def fast_linear_forward(proj, X, temp_lora = None, out = None):
352
+
353
+ W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
354
+ bsz, q_len, in_dim = X.shape
355
+ if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
356
+
357
+ if W_quant is None:
358
+ out = torch.matmul(X, W.t(), out = out)
359
+ elif bsz == 1 and q_len == 1:
360
+ out = fast_gemv(X, W, W_quant, out = out)
361
+ else:
362
+ W = fast_dequantize(W.t(), W_quant)
363
+ out = torch.matmul(X, W, out = out)
364
+ pass
365
+
366
+ # Add in LoRA weights
367
+ if lora_A is not None:
368
+ out_dim = out.shape[2]
369
+ dtype = X.dtype
370
+
371
+ if not hasattr(lora_A, "_fast_lora"):
372
+ lora_A._fast_lora = lora_A.to(dtype)
373
+ lora_B._fast_lora = lora_B.to(dtype)
374
+ pass
375
+
376
+ if bsz == 1:
377
+ out = out.view(out_dim)
378
+ temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
379
+ out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
380
+ else:
381
+ out = out.view(bsz, out_dim)
382
+ temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
383
+ out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
384
+ pass
385
+ out = out.view(bsz, 1, out_dim)
386
+ pass
387
+
388
+ if bias is not None: out += bias
389
+
390
+ return out
391
+ pass
392
+
393
+
394
+ def matmul_lora(X, W, W_quant, A, B, s, out = None):
395
+ dtype = X.dtype
396
+ W = fast_dequantize(W.t(), W_quant)
397
+
398
+ if X.dim() == 3:
399
+ batch, seq_len, d = X.shape
400
+ X = X.view(-1, X.shape[-1])
401
+ reshape = True
402
+ else:
403
+ reshape = False
404
+ pass
405
+
406
+ out = torch.matmul(X, W, out = out)
407
+ if W_quant is not None: del W
408
+
409
+ if A is not None:
410
+ # LoRA is enabled
411
+ A, B = A.t(), B.t()
412
+ out += (X @ A.to(dtype)) @ (s * B.to(dtype))
413
+ pass
414
+
415
+ return out.view(batch, seq_len, -1) if reshape else out
416
+ pass
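Ignoring the NF4 dequantization and the in-place buffer reuse, the dense math matmul_lora implements is the usual LoRA forward. A reference sketch (the function name is assumed):

import torch

def lora_linear_reference(X, W, A, B, s):
    # X: (..., in_dim); W: (out_dim, in_dim) frozen base weight
    # A: (r, in_dim);   B: (out_dim, r);     s: scaling = lora_alpha / r
    return X @ W.t() + s * ((X @ A.t()) @ B.t())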
unsloth-main/unsloth-main/unsloth/models/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .loader import FastLanguageModel
16
+ from .llama import FastLlamaModel
17
+ from .mistral import FastMistralModel
18
+ from .qwen2 import FastQwen2Model
19
+ from .dpo import PatchDPOTrainer
20
+ from ._utils import is_bfloat16_supported
unsloth-main/unsloth-main/unsloth/models/_utils.py ADDED
@@ -0,0 +1,1140 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ __version__ = "2024.9.post4"
16
+
17
+ __all__ = [
18
+ "prepare_model_for_kbit_training",
19
+ "xformers",
20
+ "xformers_attention",
21
+ "xformers_version",
22
+ "__version__",
23
+ "HAS_FLASH_ATTENTION",
24
+ "HAS_FLASH_ATTENTION_SOFTCAPPING",
25
+ "PRE_CHECK",
26
+ "platform_system",
27
+ "patch_tokenizer",
28
+ "get_statistics",
29
+ "Unsloth_Offloaded_Gradient_Checkpointer",
30
+ "offload_to_disk",
31
+ "offload_input_embeddings",
32
+ "offload_output_embeddings",
33
+ "is_bfloat16_supported",
34
+ "unsloth_offloaded_gradient_checkpoint",
35
+ "torch_compile_options",
36
+ "patch_linear_scaling",
37
+ "patch_llama_rope_scaling",
38
+ "check_nvidia",
39
+ "create_boolean_mask",
40
+ "torch_amp_custom_fwd",
41
+ "torch_amp_custom_bwd",
42
+ "accelerate_old_send_to_device",
43
+ "accelerate_new_send_to_device",
44
+ "patch_gradient_checkpointing",
45
+ "unpatch_gradient_checkpointing",
46
+ ]
47
+
48
+ import torch
49
+ from typing import Union, Optional, List, Any, Callable, Tuple
50
+ from platform import system as platform_system
51
+ platform_system = platform_system()
52
+ import numpy as np
53
+ import warnings, subprocess, re, inspect, psutil, os, math
54
+ from packaging.version import Version
55
+
56
+ # =============================================
57
+ # Disable some warnings which can get annoying
58
+ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
59
+ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
60
+ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
61
+ warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
62
+ warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
63
+ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
64
+ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
65
+ warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
66
+ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
67
+ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
68
+
69
+ # Stop "Special tokens have been added in the vocabulary, ..."
70
+ import logging
71
+ logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
72
+ # =============================================
73
+
74
+ # =============================================
75
+ # Edits all Config files to enable RoPE Scaling for all models
76
+
77
+ # Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
78
+ def patch_mistral_nemo_config(config):
79
+ if "head_dim (" not in config:
80
+ add_head_dim = "If it is not specified, will default to `8`.\n"\
81
+ " head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
82
+ " The attention head dimension."
83
+ config = config.replace("If it is not specified, will default to `8`.", add_head_dim)
84
+
85
+ add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
86
+ config = config.replace("num_key_value_heads=8,", add_head_dim)
87
+
88
+ add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
89
+ config = config.replace("self.sliding_window = sliding_window", add_head_dim)
90
+ pass
91
+ return config
92
+ pass
93
+
94
+ from transformers import __version__ as transformers_version
95
+ from transformers import PretrainedConfig
96
+ model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",]
97
+
98
+ for model_name in model_architectures:
99
+ config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
100
+ model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
101
+ config_filename = f"{model_name.title()}Config"
102
+ exec(f"from {config_filepath} import {config_filename}", globals())
103
+
104
+ try:
105
+ config = inspect.getsource(eval(config_filename))
106
+ except:
107
+ continue
108
+ if "rope_scaling" in config: continue
109
+ config = re.sub(
110
+ r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
111
+ r"rope_scaling=None,"\
112
+ r"\n **kwargs):\n"\
113
+ r"\n self.rope_scaling = rope_scaling\n",
114
+ config,
115
+ )
116
+
117
+ # Just for Mistral Nemo
118
+ if model_name == "mistral":
119
+ if Version(transformers_version) <= Version("4.42.4"):
120
+ config = patch_mistral_nemo_config(config)
121
+ pass
122
+
123
+ exec(config, globals())
124
+ exec(f"import {config_filepath}", globals())
125
+ exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
126
+ pass
127
+ # =============================================
128
+
129
+ # =============================================
130
+ # torch.cuda.amp.custom_fwd is deprecated >= 2.4
131
+ import torch
132
+ torch_version = torch.__version__
133
+ if Version(torch_version) < Version("2.4.0"):
134
+ torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
135
+ torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
136
+ else:
137
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
138
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
139
+ pass
140
+ # =============================================
141
+
142
+ # =============================================
143
+ # Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
144
+ import transformers.cache_utils
145
+ if hasattr(transformers.cache_utils, "DynamicCache") and \
146
+ transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__":
147
+
148
+ source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)
149
+ start = source.find("def")
150
+ spaces = start*" "
151
+ source = source.split("\n")
152
+ source = "\n".join(x[start:] for x in source)
153
+ where = source.find("raise KeyError")
154
+ source = source[:where] + \
155
+ f"if len(self) == 0:\n{spaces}{spaces}"\
156
+ " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \
157
+ f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:]
158
+ source = source.replace("__getitem__", "__cache_utils_getitem__", 1)
159
+ exec(source)
160
+ transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__
161
+ pass
162
+ # =============================================
163
+
164
+ # =============================================
165
+ # Get Flash Attention v2 if Ampere (RTX 30xx, A100)
166
+ import bitsandbytes as bnb
167
+ from transformers import AutoTokenizer
168
+ from transformers.utils.import_utils import _is_package_available
169
+
170
+ major_version, minor_version = torch.cuda.get_device_capability()
171
+ SUPPORTS_BFLOAT16 = False
172
+ HAS_FLASH_ATTENTION = False
173
+ HAS_FLASH_ATTENTION_SOFTCAPPING = False
174
+
175
+ if major_version >= 8:
176
+ SUPPORTS_BFLOAT16 = True
177
+ if _is_package_available("flash_attn"):
178
+ # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
179
+ try:
180
+ from flash_attn.flash_attn_interface import flash_attn_cuda
181
+ HAS_FLASH_ATTENTION = True
182
+
183
+ # Also check for softcapping
184
+ from flash_attn import __version__ as flash_attn_version
185
+ HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
186
+ if not HAS_FLASH_ATTENTION_SOFTCAPPING:
187
+ print(
188
+ "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
189
+ "Newer versions have faster kernels that use less memory for Gemma 2's attention softcapping!\n"\
190
+ "To update flash-attn, do the below:\n"\
191
+ '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
192
+ )
193
+ except:
194
+ print(
195
+ "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
196
+ "A possible explanation is you have a new CUDA version which isn't\n"\
197
+ "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
198
+ "We shall now use Xformers instead, which does not have any performance hits!\n"\
199
+ "We found this negligible impact by benchmarking on 1x A100."
200
+ "We found the impact to be negligible when benchmarking on 1x A100."
201
+
202
+ # Stop Flash Attention from importing!
203
+ import transformers.utils.import_utils
204
+ transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
205
+ import transformers.utils
206
+ transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
207
+
208
+ HAS_FLASH_ATTENTION = False
209
+ pass
210
+ else:
211
+ HAS_FLASH_ATTENTION = False
212
+ else:
213
+ # Tri Dao's benchmark shows xformers is faster for now.
214
+ HAS_FLASH_ATTENTION = False
215
+ pass
216
+
217
+ from transformers.models.llama.modeling_llama import logger
218
+
219
+ # =============================================
220
+ # Get Xformers
221
+ from xformers import __version__ as xformers_version
222
+ # Temporarily disable 0.0.27 and higher - inference issues
223
+ if False: #Version(xformers_version) >= Version("0.0.27"):
224
+ raise ImportError(
225
+ "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
226
+ "then press Disconnect Runtime and then Restart it.\n"\
227
+ "\n"\
228
+ "%%capture\n"
229
+ "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
230
+ '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
231
+ '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
232
+ '\n'\
233
+ f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\
234
+ 'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"'
235
+ )
236
+ pass
237
+
238
+ if Version(torch_version) < Version("2.2.0") and Version(xformers_version) >= Version("0.0.24"):
239
+ raise ImportError(
240
+ f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
241
+ f"Please install xformers < 0.0.24 for torch = {torch_version}."
242
+ )
243
+ elif Version(torch_version) < Version("2.3.0") and Version(xformers_version) >= Version("0.0.26"):
244
+ raise ImportError(
245
+ f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
246
+ f"Please install xformers < 0.0.26 for torch = {torch_version}."
247
+ )
248
+ elif Version(torch_version) < Version("2.4.0") and Version(xformers_version) > Version("0.0.27"):
249
+ raise ImportError(
250
+ f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
251
+ f"Please install xformers <= 0.0.27 for torch = {torch_version}."
252
+ )
253
+ pass
254
+
255
+ from xformers._cpp_lib import _register_extensions
256
+ try:
257
+ _register_extensions() # Check if C++ modules are loaded correctly
258
+ except Exception as error:
259
+ raise ImportError(
260
+ "Unsloth: Xformers was not installed correctly.\n"\
261
+ "Please install xformers separately first.\n"\
262
+ "Then confirm if it's correctly installed by running:\n"\
263
+ "python -m xformers.info\n\n"
264
+ "Longer error message:\n" + str(error)
265
+ )
266
+ pass
267
+ import xformers.ops.fmha as xformers
268
+ xformers_attention = xformers.memory_efficient_attention
269
+
270
+ # Check TRL version
271
+ from trl import __version__ as trl_version
272
+ # Unsloth now supports all TRL versions!
273
+ if False:#Version(trl_version) >= Version("0.9.0"):
274
+ raise ImportError(
275
+ "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
276
+ "then press Disconnect Runtime and then Restart it.\n"\
277
+ "\n"\
278
+ "%%capture\n"
279
+ "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
280
+ '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
281
+ '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
282
+ '\n'\
283
+ f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\
284
+ 'Please downgrade TRL via `pip install --force-reinstall trl'
285
+ )
286
+ pass
287
+
288
+ # =============================================
289
+ # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
290
+ accelerate_old_send_to_device = None
291
+ accelerate_new_send_to_device = None
292
+ if Version(xformers_version) >= Version("0.0.27"):
293
+ import accelerate.utils.operations
294
+ if hasattr(accelerate.utils.operations, "send_to_device") and \
295
+ accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
296
+ accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
297
+ from accelerate.utils.operations import *
298
+ send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
299
+ send_to_device = re.sub(
300
+ r"([ ]{4,})return tensor\.to\(device\)",
301
+ r"\1try: return tensor.to(device)\n\1except: return tensor",
302
+ send_to_device,
303
+ ).replace("def send_to_device", "def _fixed_send_to_device")
304
+ exec(send_to_device)
305
+ # accelerate.utils.operations.send_to_device = _fixed_send_to_device
306
+ accelerate_new_send_to_device = _fixed_send_to_device
307
+ pass
308
+ pass
309
+
310
+ # Transformers 4.46 breaks dynamic caching. This is a hack
311
+ import transformers.generation.configuration_utils
312
+ if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"):
313
+ if type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS) is list:
314
+ transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic")
315
+ pass
316
+ pass
317
+ # =============================================
318
+
319
+ # =============================================
320
+ # Torch compile settings
321
+
322
+ # Just remove max_autotune_gemm warning
323
+ import functools
324
+ @functools.lru_cache(None)
325
+ def is_big_gpu(index):
326
+ sms = torch.cuda.get_device_properties(index).multi_processor_count
327
+ if sms < 80: # V100
328
+ # log.warning("not enough SMs to use max_autotune_gemm mode")
329
+ return False
330
+ return True
331
+ import torch._inductor.utils
332
+ torch._inductor.utils.is_big_gpu = is_big_gpu
333
+
334
+
335
+ # Torch compile arguments
336
+ torch_compile_arguments = [
337
+ "config.dce = True",
338
+ "config.memory_planning = True",
339
+ "config.memory_pool = 'combined'",
340
+ "config.coordinate_descent_tuning = True",
341
+ "config.max_autotune_gemm = False", # GEMM is unnecessary
342
+ "config.autotune_multi_device = False",
343
+ "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster
344
+ "config.aggressive_fusion = False", # Careful: changing this changes results!
345
+ "config.cuda.enable_cuda_lto = True",
346
+ "config.cuda.use_fast_math = True",
347
+ "config.cuda.compile_opt_level = '-O2'",
348
+ ]
349
+ # Torch dynamo arguments
350
+ torch_dynamo_arguments = [
351
+ "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256
352
+ "config.suppress_errors = True", # Suppress errors for now
353
+ "config.do_not_emit_runtime_asserts = True",
354
+ "config.cache_size_limit = 1024", # Flex Attention
355
+ ]
356
+ import torch._inductor.config as config
357
+ for _try_compile_argument in torch_compile_arguments:
358
+ try: exec(_try_compile_argument)
359
+ except: pass
360
+ pass
361
+ import torch._dynamo.config as config
362
+ for _try_dynamo_argument in torch_dynamo_arguments:
363
+ try: exec(_try_dynamo_argument)
364
+ except: pass
365
+ pass
366
+ torch_compile_options = {
367
+ "epilogue_fusion" : True,
368
+ "max_autotune" : True,
369
+ "shape_padding" : True,
370
+ "trace.enabled" : False, # Output Triton kernel outputs!
371
+ "triton.cudagraphs" : False,
372
+ }
373
+ # =============================================
374
+
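These options are intended to be handed straight to torch.compile. A hypothetical usage sketch (mlp is any module or function you want compiled, not something defined in this file):

# Hypothetical: `mlp` is any callable / nn.Module.
compiled_mlp = torch.compile(mlp, fullgraph = False, options = torch_compile_options)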
375
+ def prepare_model_for_kbit_training(
376
+ model : Any,
377
+ use_gradient_checkpointing : Optional = True,
378
+ use_reentrant : Optional[bool] = True,
379
+ ) -> Any:
380
+ """
381
+ Calculates where to place the gradient checkpoints given n_layers.
382
+ We also freeze all other layers' gradients
383
+
384
+ Args:
385
+ model: Any LlamaModel with layers.
386
+ use_gradient_checkpointing (`bool`, *optional*):
387
+ Default enabled. Provides memory savings by not saving all activations,
388
+ but only some.
389
+ use_reentrant (`bool`, *optional*):
390
+ https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
391
+ Optimal gradient checkpointing algorithm which will be the default in
392
+ future Pytorch versions.
393
+ """
394
+
395
+ # Freeze all parameters except LoRA
396
+ with torch.no_grad():
397
+ for name, param in model.named_parameters():
398
+ if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
399
+ param.requires_grad_(True)
400
+ # Also must be in float32!
401
+ if param.dtype != torch.float32:
402
+ name = name.replace("base_model", "model", 1)
403
+ layer_number = re.search(r"\.[\d]{1,}\.", name).group(0)
404
+ name = name.replace(layer_number, f"[{layer_number[1:-1]}].")
405
+ name = name.replace(".weight", "", 1)
406
+ exec(f"{name}.to(torch.float32)")
407
+ pass
408
+ else:
409
+ param.requires_grad_(False)
410
+ pass
411
+ pass
412
+
413
+ # Gradient checkpointing!
414
+ if use_gradient_checkpointing == "unsloth":
415
+
416
+ # Saves VRAM!
417
+ original_model = model
418
+ while hasattr(original_model, "model"):
419
+ original_model._offloaded_gradient_checkpointing = True
420
+ original_model = original_model.model
421
+ pass
422
+ original_model._offloaded_gradient_checkpointing = True
423
+
424
+ model.gradient_checkpointing_enable()
425
+
426
+ elif use_gradient_checkpointing == True:
427
+ model.gradient_checkpointing_enable()
428
+ pass
429
+
430
+ # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad.
431
+ if use_reentrant:
432
+ if hasattr(model, "enable_input_require_grads"):
433
+ model.enable_input_require_grads()
434
+ else:
435
+ def make_inputs_require_grad(module, input, output):
436
+ output.requires_grad_(True)
437
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
438
+
439
+ return model
440
+ pass
441
+
442
+
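+ # Illustrative sketch (not part of the upstream file): a typical call on a PEFT/LoRA
+ # model; "unsloth" selects the offloaded gradient checkpointing path defined below.
+ def _example_prepare_for_training(peft_model):
+     return prepare_model_for_kbit_training(
+         peft_model,
+         use_gradient_checkpointing = "unsloth",
+         use_reentrant = True,
+     )
+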
443
+ def patch_tokenizer(model, tokenizer):
444
+ """
445
+ Phi3's pad_token isn't set. We set it to <|placeholder...
446
+ Llama-3 is <|reserved...
447
+ Llama-2 is <unk>
448
+ Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
449
+ Fixes https://github.com/unslothai/unsloth/issues/5
450
+ """
451
+ possible_reserved_tokens = (
452
+ "<|finetune_right_pad_id|>", # Llama-3.1
453
+ "<pad>", # Mistral Nemo
454
+ "<|reserved", # Llama-3
455
+ "<|placeholder", # Phi-3
456
+ "[control", # Mistral type models
457
+ )
458
+ joiner = "\1\0=+=\0\1"
459
+ number_repetitions = 3 - 1 # Number of reserved tokens needed
460
+
461
+ if model is not None:
462
+ model.config.update({"unsloth_version" : __version__})
463
+
464
+ bad_pad_token = False
465
+ if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None:
466
+ # Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
467
+ bad_pad_token = tokenizer.eos_token == tokenizer.pad_token
468
+ elif hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
469
+ bad_pad_token = True
470
+ else:
471
+ bad_pad_token = False
472
+ pass
473
+
474
+ if bad_pad_token:
475
+ # Find a better pad token
476
+ added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
477
+ all_added_tokens = joiner.join(added_tokens[::-1])
478
+ all_added_tokens += joiner
479
+
480
+ final_pad_token = None
481
+ final_good_match = False
482
+
483
+ for possible_reserved_token in possible_reserved_tokens:
484
+ possible_reserved_token = re.escape(possible_reserved_token)
485
+ found = re.finditer(f"{possible_reserved_token}", all_added_tokens)
486
+ first_match = None
487
+ good_match = False
488
+ for j, x in enumerate(found):
489
+ if j == 0: first_match = x
490
+ if j >= number_repetitions:
491
+ good_match = True
492
+ break
493
+ pass
494
+ pass
495
+
496
+ if first_match is None: continue
497
+
498
+ # If it ends with |> or > etc, then set it as a good pad token!
499
+ start = first_match.span(0)[0]
500
+ possible_pad_token = first_match.group(0)
501
+ end = all_added_tokens.find(joiner, start)
502
+ first_match = all_added_tokens[start:end]
503
+
504
+ if first_match is not None:
505
+ good_match = possible_pad_token.endswith((">", "|>", "]", ")"))
506
+ pass
507
+ possible_pad_token = first_match
508
+
509
+ # Replace current pad token if another exact match is found
510
+ if not final_good_match and good_match:
511
+ final_good_match = True
512
+ final_pad_token = possible_pad_token
513
+ break
514
+ else:
515
+ final_good_match = False
516
+ final_pad_token = possible_pad_token
517
+ pass
518
+ pass
519
+ possible_pad_token = final_pad_token
520
+
521
+ # Try unk_token
522
+ if possible_pad_token is None and hasattr(tokenizer, "unk_token"):
523
+ possible_pad_token = tokenizer.unk_token
524
+ pass
525
+
526
+ # Check pad token's id must be less than vocab size
527
+ if possible_pad_token is not None:
528
+ check_pad_token = tokenizer(possible_pad_token, add_special_tokens = False).input_ids
529
+ if len(check_pad_token) != 1:
530
+ possible_pad_token = None
531
+ if model is not None and check_pad_token[0] >= model.config.vocab_size:
532
+ possible_pad_token = None
533
+ pass
534
+
535
+ if possible_pad_token is None:
536
+ # Failure to find a good replacement!! We shall manually add one!
537
+ new_pad_token = "<|PAD_TOKEN|>"
538
+ while new_pad_token in tokenizer.get_vocab():
539
+ new_pad_token = f"<{new_pad_token}>"
540
+ pass
541
+ possible_pad_token = new_pad_token
542
+ pass
543
+
544
+ name = model.config._name_or_path if model is not None else "Model"
545
+ logger.warning_once(
546
+ f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}."
547
+ )
548
+
549
+ # Edit pad_token
550
+ tokenizer.add_special_tokens({"pad_token" : possible_pad_token})
551
+ tokenizer.pad_token = possible_pad_token
552
+ if model is not None:
553
+ model.config.update({"pad_token_id" : tokenizer.pad_token_id})
554
+ if getattr(model, "generation_config") is not None:
555
+ model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
556
+ else:
557
+ if model is not None:
558
+ if model.config.pad_token_id is None:
559
+ model.config.update({"pad_token_id" : tokenizer.pad_token_id})
560
+ if getattr(model, "generation_config") is not None:
561
+ model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
562
+ pass
563
+ pass
564
+
565
+ if model is not None:
566
+ if getattr(model, "generation_config") is not None:
567
+ model.generation_config.update(max_length = model.config.max_position_embeddings)
568
+
569
+ return model, tokenizer
570
+ pass
571
+
572
+
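+ # Illustrative sketch (not part of the upstream file): `model` and `tokenizer` are
+ # assumed to come from transformers' `from_pretrained`; the patch picks or adds a
+ # pad token that differs from the EOS token so padded labels are not ignored incorrectly.
+ def _example_patch_tokenizer(model, tokenizer):
+     model, tokenizer = patch_tokenizer(model, tokenizer)
+     assert tokenizer.pad_token != tokenizer.eos_token
+     return model, tokenizer
+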
573
+ # =============================================
574
+ # Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
575
+ # For mixed precision, we need it to be in float32 not float16.
576
+ from peft import __version__ as peft_version
577
+ if Version(peft_version) < Version("0.12.0"):
578
+ from peft.tuners.lora.layer import LoraLayer
579
+ try:
580
+ source = inspect.getsource(LoraLayer.update_layer)
581
+ text = "if weight is not None:\n"
582
+ start = source.find(text) + len(text)
583
+ end = source.find("self.to(weight.device)", start)
584
+ spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
585
+ source = source.replace(source[start : end], spaces)
586
+ spaces = len(re.match(r"[\s]{1,}", source).group(0))
587
+ lines = source.split("\n")
588
+ source = "\n".join(x[spaces:] for x in lines)
589
+ source = re.sub("([^\.])nn\.", r"\1torch.nn.", source)
590
+ source = source.replace("def update_layer", "def LoraLayer_update_layer")
591
+ exec(source, globals())
592
+
593
+ # Fix up incorrect downcasting of LoRA weights
594
+ from peft.tuners.lora.layer import LoraLayer
595
+ LoraLayer.update_layer = LoraLayer_update_layer
596
+ from peft.tuners.lora import LoraLayer
597
+ LoraLayer.update_layer = LoraLayer_update_layer
598
+ except:
599
+ logger.warning_once(
600
+ "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\
601
+ "Luckily, your training run will still work in the meantime!"
602
+ )
603
+ pass
604
+ pass
605
+ # =============================================
606
+
607
+ import psutil
608
+ def _get_statistics(statistics = None, force_download = True):
609
+ # We log some basic stats about which environment is being used.
610
+ # We simply download a README.md file from HF - all data is made public.
611
+ # This is simply so we can check if some envs are broken or not.
612
+ # You can disable this by commenting the below out
613
+ try:
614
+ n_cpus = psutil.cpu_count(logical = False)
615
+ keynames = "\n" + "\n".join(os.environ.keys())
616
+ if statistics is not None: pass
617
+ elif "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab"
618
+ elif "\nCOLAB_" in keynames: statistics = "colabpro"
619
+ elif "\nKAGGLE_" in keynames: statistics = "kaggle"
620
+ elif "\nRUNPOD_" in keynames: statistics = "runpod"
621
+ elif "\nAWS_" in keynames: statistics = "aws"
622
+ elif "\nAZURE_" in keynames: statistics = "azure"
623
+ # elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
624
+ elif "\nINVOCATION_ID" in keynames: statistics = "lambda"
625
+ # else: statistics = "other"
626
+ else:
627
+ def try_vllm_check():
628
+ vendor_files = (
629
+ "/sys/class/dmi/id/product_version",
630
+ "/sys/class/dmi/id/bios_vendor",
631
+ "/sys/class/dmi/id/product_name",
632
+ "/sys/class/dmi/id/chassis_asset_tag",
633
+ "/sys/class/dmi/id/sys_vendor",
634
+ )
635
+ from pathlib import Path
636
+ for vendor_file in vendor_files:
637
+ path = Path(vendor_file)
638
+ if path.is_file():
639
+ file_content = path.read_text().lower()
640
+ if "amazon" in file_content: return "aws"
641
+ elif "microsoft corporation" in file_content: return "azure"
642
+ elif "google" in file_content: return "gcp"
643
+ return "other"
644
+ pass
645
+ try: statistics = try_vllm_check()
646
+ except: statistics = "other"
647
+ pass
648
+ if statistics is not None:
649
+ from transformers import AutoModelForCausalLM
650
+ stats_model = AutoModelForCausalLM.from_pretrained(
651
+ f"unslothai/{statistics}",
652
+ force_download = force_download,
653
+ )
654
+ del stats_model
655
+ pass
656
+ except:
657
+ pass
658
+ pass
659
+
660
+
661
+ def get_statistics():
662
+ # We log some basic stats about which environment is being used.
663
+ # We simply download a README.md file from HF - all data is made public.
664
+ # This is simply so we can check if some envs are broken or not.
665
+ # You can disable this by commenting the below out
666
+ from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
667
+ disabled = False
668
+ if not are_progress_bars_disabled():
669
+ disable_progress_bars()
670
+ disabled = True
671
+ pass
672
+ _get_statistics(None)
673
+ _get_statistics("repeat", force_download = False)
674
+ try:
675
+ vram = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024
676
+ if vram <= 8 : vram = 8
677
+ elif vram <= 16: vram = 16
678
+ elif vram <= 20: vram = 20
679
+ elif vram <= 24: vram = 24
680
+ elif vram <= 40: vram = 40
681
+ elif vram <= 48: vram = 48
682
+ elif vram <= 80: vram = 80
683
+ else: vram = 96
684
+ _get_statistics(f"vram-{vram}")
685
+ except:
686
+ pass
687
+ pass
688
+ try:
689
+ devices = torch.cuda.device_count()
690
+ _get_statistics(f"{devices if devices <= 8 else 9}")
691
+ except:
692
+ pass
693
+ if disabled: enable_progress_bars()
694
+ pass
695
+
696
+
697
+ def _calculate_n_gradient_checkpoints(
698
+ n_layers : int,
699
+ method : Optional[Union[str, int]] = "sqrt",
700
+ ) -> List[int]:
701
+ assert(type(n_layers) is int and n_layers > 0)
702
+
703
+ if method is None: method = "sqrt"
704
+
705
+ if method == "sqrt":
706
+ n_checkpoints = int(n_layers**0.5)
707
+ elif type(method) is int and method > 0:
708
+ n_checkpoints = int(np.ceil(n_layers / method))
709
+ else:
710
+ raise ValueError("method must be 'sqrt' or an int >0 and <= n_layers.")
711
+
712
+ size = n_layers // n_checkpoints
713
+ sizes = np.full(n_checkpoints, size, dtype = int)
714
+ leftovers = n_layers % n_checkpoints
715
+ # We append leftovers from the right
716
+ for k in range(leftovers):
717
+ sizes[n_checkpoints-1-k] += 1
718
+ boundaries = np.hstack((0, np.cumsum(sizes)))
719
+ boundaries = boundaries.tolist()
720
+ return boundaries
721
+ pass
722
+
723
+
724
+ def calculate_n_gradient_checkpoints(
725
+ n_layers : int,
726
+ layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
727
+ ) -> List[int]:
728
+ assert(type(n_layers) is int and n_layers > 0)
729
+
730
+ if layers_per_checkpoint is None or layers_per_checkpoint == 1:
731
+ return None
732
+
733
+ boundaries = _calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
734
+
735
+ assert(boundaries[0] == 0 and boundaries[-1] == n_layers)
736
+ assert(min(boundaries) == 0 and max(boundaries) == n_layers)
737
+ assert(np.diff(boundaries).min() >= 0)
738
+ return boundaries
739
+ pass
740
+
741
+
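+ # Illustrative sketch (not part of the upstream file): with 32 layers and the default
+ # "sqrt" policy, 5 checkpoints are placed and the layer boundaries come out as below.
+ def _example_checkpoint_boundaries():
+     boundaries = calculate_n_gradient_checkpoints(32, "sqrt")
+     assert boundaries == [0, 6, 12, 18, 25, 32]
+     return boundaries
+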
742
+ def prepare_n_gradient_checkpoints(
743
+ model : Any,
744
+ layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
745
+ use_reentrant : Optional[bool] = True,
746
+ ) -> None:
747
+ """
748
+ Calculates where to place the gradient checkpoints given n_layers.
749
+
750
+ Args:
751
+ model: Any LlamaModel with layers.
752
+ layers_per_checkpoint (`Union[str, int]`, *optional*):
753
+ Can either be `sqrt` or an integer for how many layers per checkpoint you want.
754
+ The more layers per checkpoint, the lower the memory usage, but the slower it can be. Default is `sqrt`.
755
+ Choose 1 for Pytorch gradient checkpointing. 2 to wrap 2 layers in 1 module etc.
756
+ use_reentrant (`bool`, *optional*):
757
+ https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
758
+ The non-reentrant algorithm (`use_reentrant=False`), which will become
758
+ the default in future PyTorch versions, does not seem to work here.
760
+ """
761
+ _model = None
762
+ if hasattr(model, "layers"):
763
+ _model = model
764
+ elif hasattr(model, "model"):
765
+ if hasattr(model.model, "layers"):
766
+ _model = model.model
767
+ if _model is None:
768
+ raise TypeError("`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?")
769
+ pass
770
+
771
+ if use_reentrant is False:
772
+ use_reentrant = True
773
+ pass
774
+
775
+ n_layers = len(_model.layers)
776
+ boundaries = calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
777
+ _model._gradient_checkpointing_boundaries = boundaries
778
+ _model._gradient_checkpointing_use_reentrant = use_reentrant
779
+ pass
780
+
781
+
782
+ class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
783
+ """
784
+ Saves VRAM by smartly offloading to RAM.
785
+ Tiny performance hit, since we hide the data movement behind non-blocking calls.
786
+ """
787
+ @staticmethod
788
+ @torch_amp_custom_fwd
789
+ def forward(ctx, forward_function, hidden_states, *args):
790
+ saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
791
+ with torch.no_grad():
792
+ output = forward_function(hidden_states, *args)
793
+ ctx.save_for_backward(saved_hidden_states)
794
+ ctx.forward_function = forward_function
795
+ ctx.args = args
796
+ return output
797
+ pass
798
+
799
+ @staticmethod
800
+ @torch_amp_custom_bwd
801
+ def backward(ctx, dY):
802
+ (hidden_states,) = ctx.saved_tensors
803
+ hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach()
804
+ hidden_states.requires_grad_(True)
805
+ with torch.enable_grad():
806
+ (output,) = ctx.forward_function(hidden_states, *ctx.args)
807
+ torch.autograd.backward(output, dY)
808
+ return (None, hidden_states.grad,) + (None,)*len(ctx.args)
809
+ pass
810
+ pass
811
+
812
+
813
+ @torch._disable_dynamo
814
+ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs):
815
+ return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args)
816
+ pass
817
+
818
+
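+ # Illustrative sketch (not part of the upstream file): wrapping one decoder layer's
+ # forward pass in the offloaded checkpointer; the layer is assumed to return a tuple,
+ # since the backward pass above unpacks `(output,)`.
+ def _example_offloaded_layer_call(decoder_layer, hidden_states, *layer_args):
+     return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+         decoder_layer.__call__, hidden_states, *layer_args,
+     )
+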
819
+ import torch.utils
820
+ old_checkpoint = torch.utils.checkpoint
821
+ def patch_gradient_checkpointing():
822
+ torch.utils.checkpoint = unsloth_offloaded_gradient_checkpoint
823
+ pass
824
+
825
+ def unpatch_gradient_checkpointing():
826
+ torch.utils.checkpoint = old_checkpoint
827
+ pass
828
+
829
+
830
+ # =============================================
831
+ # Fixes Bitsandbytes to remove missing warnings
832
+ from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
833
+ from inspect import getsource
834
+ from accelerate.utils.dataclasses import DistributedType
835
+ BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
836
+ BitsAndBytesConfig__init__ = re.sub(
837
+ r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
838
+ "",
839
+ BitsAndBytesConfig__init__,
840
+ flags = re.MULTILINE,
841
+ )
842
+ BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
843
+ length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
844
+ BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
845
+ BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
846
+ "__init__",
847
+ "_BitsAndBytesConfig__init__",
848
+ )
849
+
850
+ def _prepare_backend(
851
+ self, cpu: bool = False, sagemaker_dp = False, backend: str = None,
852
+ ) -> tuple[str, DistributedType]:
853
+ return None, DistributedType.NO
854
+ pass
855
+ import accelerate.state
856
+ accelerate.state.PartialState._prepare_backend = _prepare_backend
857
+
858
+ import accelerate.accelerator
859
+ prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
860
+ prepare = prepare.split("\n")
861
+ spaces = prepare[0].find("def")
862
+ prepare = "\n".join(x[spaces:] for x in prepare)
863
+ x = "for obj in args:"
864
+ s = " "*spaces
865
+ prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
866
+ exec(prepare, globals())
867
+ accelerate.accelerator.Accelerator.prepare = prepare
868
+
869
+ exec(BitsAndBytesConfig__init__, globals())
870
+
871
+ import transformers.utils.quantization_config
872
+ transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
873
+ # =============================================
874
+
875
+ # Offloading to disk for modules (lm_head, embed_tokens)
876
+ import pickle
877
+
878
+ def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"):
879
+ file_location = os.path.join(temporary_location, model.config._name_or_path)
880
+ if not os.path.exists(file_location):
881
+ os.makedirs(file_location)
882
+ pass
883
+
884
+ filename = os.path.join(file_location, f"{name}.pt")
885
+ W = W.weight if hasattr(W, "weight") else W
886
+ torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
887
+ offloaded_W = torch.load(filename, map_location = "cpu", mmap = True)
888
+ offloaded_W._offloaded_file_location = filename
889
+ return offloaded_W
890
+ pass
891
+
892
+
893
+ def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
894
+ offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location)
895
+ new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
896
+ new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
897
+ model.set_input_embeddings(new_input_embeddings)
898
+ return
899
+ pass
900
+
901
+
902
+ def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
903
+ offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location)
904
+
905
+ new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
906
+ del new_output_embeddings.weight
907
+ new_output_embeddings.weight = offloaded_W
908
+ new_output_embeddings.in_features = offloaded_W.shape[1]
909
+ new_output_embeddings.out_features = offloaded_W.shape[0]
910
+
911
+ new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
912
+ model.set_output_embeddings(new_output_embeddings)
913
+ return
914
+ pass
915
+
916
+
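+ # Illustrative sketch (not part of the upstream file): memory-map both embedding
+ # matrices to disk before training; the directory name is the default used above.
+ def _example_offload_embeddings(model):
+     offload_input_embeddings (model, "_unsloth_temporary_saved_buffers")
+     offload_output_embeddings(model, "_unsloth_temporary_saved_buffers")
+     return model
+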
917
+ # Fixes a weird Torch 2.3 bug which says T4s have bfloat16
918
+ def is_bfloat16_supported():
919
+ return SUPPORTS_BFLOAT16
920
+ pass
921
+
922
+
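+ # Illustrative sketch (not part of the upstream file): pick a training dtype using the
+ # check above instead of trusting the raw device capability report on older GPUs.
+ def _example_pick_dtype():
+     return torch.bfloat16 if is_bfloat16_supported() else torch.float16
+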
923
+ # Patches models to add RoPE Scaling
924
+ def patch_linear_scaling(
925
+ model_name = "gemma2",
926
+ rope_module = None,
927
+ scaled_rope_module = None,
928
+ attention_module = None,
929
+ ):
930
+ assert(rope_module is not None and scaled_rope_module is not None)
931
+ assert(attention_module is not None)
932
+
933
+ rope_name = rope_module.__name__
934
+ scaled_rope_name = scaled_rope_module.__name__
935
+ model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
936
+ exec_code = \
937
+ f"import torch.nn as nn\n"\
938
+ f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
939
+ f"from {model_filepath} import logger, "\
940
+ f"{model_name.title()}Attention, {model_name.title()}Config"
941
+
942
+ try:
943
+ function = inspect.getsource(attention_module.__init__)
944
+ except:
945
+ # Most likely already patched!
946
+ return None, None
947
+ where = function.find("def")
948
+ function = function.split("\n")
949
+ function = "\n".join(x[where:] for x in function)
950
+ init_name = f"{model_name.title()}Attention__init__"
951
+ function = function.replace("def __init__", f"def {init_name}")
952
+ function = function.replace(
953
+ "super().__init__()",
954
+ f"super({model_name.title()}Attention, self).__init__()",
955
+ )
956
+ fix_rope_function = """
957
+ if getattr(self.config, "rope_scaling", None) is None:
958
+ self.rotary_emb = {rope_function}(
959
+ dim = self.head_dim,
960
+ max_position_embeddings=self.max_position_embeddings,
961
+ base=self.rope_theta,
962
+ )
963
+ else:
964
+ scaling_type = self.config.rope_scaling["type"]
965
+ scaling_factor = self.config.rope_scaling["factor"]
966
+ if scaling_type == "linear":
967
+ self.rotary_emb = {scaled_rope_function}(
968
+ dim = self.head_dim,
969
+ max_position_embeddings=self.max_position_embeddings,
970
+ scaling_factor=scaling_factor,
971
+ base=self.rope_theta,
972
+ )
973
+ else:
974
+ raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
975
+ pass
976
+ """
977
+ fix_rope_function = fix_rope_function.format(
978
+ rope_function = rope_module.__name__,
979
+ scaled_rope_function = scaled_rope_module.__name__,
980
+ )
981
+ rotary_emb = re.findall(
982
+ "self.rotary_emb = .+?\)", function,
983
+ flags = re.DOTALL | re.MULTILINE,
984
+ )
985
+ if len(rotary_emb) == 0: return None, function
986
+ rotary_emb = rotary_emb[0]
987
+ function = function.replace(rotary_emb, fix_rope_function, 1)
988
+ function = exec_code + "\n\n" + function
989
+ return init_name, function
990
+ pass
991
+
992
+
993
+ # Patches for Llama-3 LlamaExtendedRotaryEmbedding
994
+ def patch_llama_rope_scaling(
995
+ model_name = "llama",
996
+ rope_module = None,
997
+ scaled_rope_module = None,
998
+ extended_rope_module = None,
999
+ attention_module = None,
1000
+ longrope_module = None,
1001
+ ):
1002
+ assert(\
1003
+ rope_module is not None and \
1004
+ scaled_rope_module is not None and \
1005
+ extended_rope_module is not None
1006
+ )
1007
+ assert(attention_module is not None)
1008
+
1009
+ rope_name = rope_module.__name__
1010
+ scaled_rope_name = scaled_rope_module.__name__
1011
+ model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
1012
+ exec_code = \
1013
+ f"import torch.nn as nn\n"\
1014
+ f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
1015
+ f"from {model_filepath} import logger, "\
1016
+ f"{model_name.title()}Attention, {model_name.title()}Config"
1017
+
1018
+ try:
1019
+ function = inspect.getsource(attention_module.__init__)
1020
+ except:
1021
+ # Most likely already patched!
1022
+ return None, None
1023
+ where = function.find("def")
1024
+ function = function.split("\n")
1025
+ function = "\n".join(x[where:] for x in function)
1026
+ init_name = f"{model_name.title()}Attention__init__"
1027
+ function = function.replace("def __init__", f"def {init_name}")
1028
+ function = function.replace(
1029
+ "super().__init__()",
1030
+ f"super({model_name.title()}Attention, self).__init__()",
1031
+ )
1032
+ fix_rope_function = """
1033
+ if getattr(self.config, "rope_scaling", None) is None:
1034
+ self.rotary_emb = {rope_function}(
1035
+ dim = self.head_dim,
1036
+ max_position_embeddings=self.max_position_embeddings,
1037
+ base=self.rope_theta,
1038
+ )
1039
+ else:
1040
+ scaling_type1 = self.config.rope_scaling.get("type", None)
1041
+ scaling_type2 = self.config.rope_scaling.get("rope_type", None)
1042
+ scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
1043
+ scaling_factor = self.config.rope_scaling.get("factor")
1044
+
1045
+ if scaling_type == "linear":
1046
+ self.rotary_emb = {scaled_rope_function}(
1047
+ dim = self.head_dim,
1048
+ max_position_embeddings=self.max_position_embeddings,
1049
+ scaling_factor=scaling_factor,
1050
+ base=self.rope_theta,
1051
+ )
1052
+ elif scaling_type == "llama3":
1053
+ self.rotary_emb = {extended_rope_function}(
1054
+ dim = self.head_dim,
1055
+ max_position_embeddings=self.max_position_embeddings,
1056
+ base=self.rope_theta,
1057
+ )
1058
+ elif scaling_type == "longrope":
1059
+ self.rotary_emb = {longrope_rope_function}(
1060
+ dim = self.head_dim,
1061
+ max_position_embeddings = self.max_position_embeddings,
1062
+ original_max_position_embeddings = self.config.original_max_position_embeddings,
1063
+ base = self.rope_theta,
1064
+ short_factor = self.config.rope_scaling['short_factor'],
1065
+ long_factor = self.config.rope_scaling['long_factor' ],
1066
+ )
1067
+ else:
1068
+ raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
1069
+ pass
1070
+ """
1071
+
1072
+ fix_rope_function = fix_rope_function.format(
1073
+ rope_function = rope_module.__name__,
1074
+ scaled_rope_function = scaled_rope_module.__name__,
1075
+ extended_rope_function = extended_rope_module.__name__,
1076
+ longrope_rope_function = \
1077
+ (longrope_module if longrope_module is not None else rope_module).__name__
1078
+ )
1079
+ rotary_emb = re.findall(
1080
+ "self.rotary_emb = .+?\)", function,
1081
+ flags = re.DOTALL | re.MULTILINE,
1082
+ )
1083
+ if len(rotary_emb) == 0: return None, function
1084
+ rotary_emb = rotary_emb[0]
1085
+ function = function.replace(rotary_emb, fix_rope_function, 1)
1086
+ function = exec_code + "\n\n" + function
1087
+ return init_name, function
1088
+ pass
1089
+
1090
+
1091
+ def check_nvidia():
1092
+ # Unsloth doesn't work yet on AMD devices - we're working on it!
1093
+ output = np.array([0,])
1094
+ try:
1095
+ output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
1096
+ output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
1097
+ output = np.array([int(x.decode('utf-8'))/1024 for x in output])
1098
+ except:
1099
+ if not torch.cuda.is_available():
1100
+ raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
1101
+ return output
1102
+ pass
1103
+ PRE_CHECK = check_nvidia()
1104
+
1105
+
1106
+ def create_boolean_mask(n = 4096, sliding_window = 2048):
1107
+ # Creates a boolean mask for attention
1108
+ mask = torch.ones(n, n, dtype = torch.bool)
1109
+ if sliding_window == 0:
1110
+ return torch.triu(mask, diagonal = 1, out = mask)
1111
+ pass
1112
+ torch.triu(mask, diagonal = 0, out = mask)
1113
+ torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
1114
+ mask = mask.T
1115
+ torch.logical_not(mask, out = mask)
1116
+ return mask
1117
+ pass
1118
+
1119
+
1120
+ def test_mask_creation():
1121
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
1122
+ for n in range(2, 23):
1123
+ for s in range(1, 23):
1124
+ correct_mask = AttentionMaskConverter(
1125
+ is_causal = True,
1126
+ sliding_window = s,
1127
+ ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
1128
+ correct_mask = (correct_mask == correct_mask.min())
1129
+ our_mask = create_boolean_mask(n = n, sliding_window = s)
1130
+ assert(torch.all(correct_mask == our_mask))
1131
+ pass
1132
+ correct_mask = AttentionMaskConverter(
1133
+ is_causal = True,
1134
+ sliding_window = None,
1135
+ ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
1136
+ correct_mask = (correct_mask == correct_mask.min())
1137
+ our_mask = create_boolean_mask(n = n, sliding_window = 0)
1138
+ assert(torch.all(correct_mask == our_mask))
1139
+ pass
1140
+ pass
unsloth-main/unsloth-main/unsloth/models/cohere.py ADDED
@@ -0,0 +1,473 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .llama import *
16
+ from ._utils import __version__
17
+ try:
18
+ from transformers.models.cohere.modeling_cohere import (
19
+ CohereAttention,
20
+ CohereDecoderLayer,
21
+ CohereModel,
22
+ CohereForCausalLM,
23
+ CohereRotaryEmbedding,
24
+ apply_rotary_pos_emb,
25
+ repeat_kv,
26
+ )
27
+ except:
28
+ from packaging.version import Version
29
+ transformers_version = Version(transformers_version)
30
+ if not transformers_version >= Version("4.42"):
31
+ raise ImportError(
32
+ f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"\
33
+ f"The minimum required version is 4.42.3.\n"\
34
+ f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
35
+ f"to obtain the latest transformers build, then restart this session."\
36
+ )
37
+ pass
38
+ pass
39
+
40
+ from transformers.modeling_attn_mask_utils import (
41
+ _prepare_4d_causal_attention_mask_for_sdpa,
42
+ )
43
+ # For Pytorch 2.1.1
44
+ try:
45
+ from transformers.models.cohere.modeling_cohere import (
46
+ CohereSdpaAttention,
47
+ CohereFlashAttention2,
48
+ )
49
+ except:
50
+ CohereSdpaAttention = CohereAttention
51
+ CohereFlashAttention2 = CohereAttention
52
+ pass
53
+
54
+
55
+ def fast_layernorm_inference(self, X, out_weight = None):
56
+ XX = X.to(torch.float32, copy = True)
57
+ XX -= X.mean(-1, keepdim = True)
58
+ variance = XX.square().mean(-1, keepdim = True)
59
+ variance += self.variance_epsilon
60
+ XX *= variance.rsqrt_()
61
+ out_weight[:] = self.weight
62
+ XX *= out_weight
63
+ return XX.to(X.dtype)
64
+ pass
65
+
66
+
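+ # Illustrative sketch (not part of the upstream file): for a 1-D `layer.weight` over the
+ # hidden dimension (e.g. input_layernorm), the fused path above should match a bias-free
+ # LayerNorm computed in float32.
+ def _example_layernorm_reference(layer, X):
+     ref = torch.nn.functional.layer_norm(
+         X.float(), (X.shape[-1],),
+         weight = layer.weight.float(), bias = None,
+         eps = layer.variance_epsilon,
+     )
+     return ref.to(X.dtype)
+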
67
+ # QK norm in Cohere
68
+ def CohereAttention_fast_forward(
69
+ self,
70
+ hidden_states: torch.Tensor,
71
+ causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ position_ids: Optional[torch.LongTensor] = None,
74
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
75
+ output_attentions: bool = False,
76
+ use_cache: bool = False,
77
+ padding_mask: Optional[torch.LongTensor] = None,
78
+ *args, **kwargs,
79
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
80
+
81
+ # Clear inference
82
+ if hasattr(self, "paged_attention"):
83
+ del self.paged_attention_K
84
+ del self.paged_attention_V
85
+ del self.paged_attention
86
+ del self.temp_QA
87
+ del self.temp_KV
88
+ del self.RH_Q
89
+ del self.attention
90
+ del self.q_norm_out_weight
91
+ del self.k_norm_out_weight
92
+ pass
93
+
94
+ bsz, q_len, _ = hidden_states.size()
95
+
96
+ n_heads = self.num_heads
97
+ n_groups = self.num_key_value_groups
98
+ n_kv_heads = self.num_key_value_heads
99
+ head_dim = self.head_dim
100
+ assert(n_kv_heads * n_groups == n_heads)
101
+
102
+ Q, K, V = self.apply_qkv(self, hidden_states)
103
+ Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
104
+ K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
105
+ V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
106
+ if self.use_qk_norm:
107
+ Q = fast_layernorm_compiled(self.q_norm, Q)
108
+ K = fast_layernorm_compiled(self.k_norm, K)
109
+ pass
110
+
111
+ kv_seq_len = K.shape[-2]
112
+ if past_key_value is not None:
113
+ kv_seq_len += past_key_value[0].shape[-2]
114
+
115
+ if position_ids is None:
116
+ cos = self.rotary_emb.cos_cached
117
+ sin = self.rotary_emb.sin_cached
118
+ Q, K = fast_rope_embedding(Q, K, cos, sin)
119
+ else:
120
+ cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
121
+ Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
122
+ pass
123
+
124
+ if past_key_value is not None:
125
+ K = torch.cat([past_key_value[0], K], dim = 2)
126
+ V = torch.cat([past_key_value[1], V], dim = 2)
127
+ pass
128
+ past_key_value = (K, V) if use_cache else None
129
+
130
+ # Attention module
131
+ if (not HAS_FLASH_ATTENTION and attention_mask is None):
132
+ # Xformers memory efficient attention
133
+ # Also has Flash Attention v2 dispatching
134
+ Q = Q.transpose(1, 2)
135
+ K = K.transpose(1, 2)
136
+ V = V.transpose(1, 2)
137
+
138
+ # Group query attention
139
+ if n_groups != 1:
140
+ K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
141
+ V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
142
+ K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
143
+ V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
144
+ if hidden_states.requires_grad:
145
+ K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
146
+ V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
147
+ else:
148
+ Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
149
+ pass
150
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask)
151
+ A = A.view(bsz, q_len, n_heads, head_dim)
152
+
153
+ elif HAS_FLASH_ATTENTION and attention_mask is None:
154
+ Q = Q.transpose(1, 2)
155
+ K = K.transpose(1, 2)
156
+ V = V.transpose(1, 2)
157
+ A = flash_attn_func(Q, K, V, causal = True)
158
+ else:
159
+ # Grouped query attention
160
+ if n_groups != 1:
161
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
162
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
163
+ K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
164
+ V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
165
+ pass
166
+ # Must be contiguous or else results are wrong!
167
+ # https://github.com/pytorch/pytorch/issues/112577
168
+ Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
169
+ # Needs (batch_size, n_heads, seq_len, head_dim)
170
+ # is_causal and attention_mask must not both be set!
171
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
172
+ # Go back to (batch_size, seq_len, n_heads, head_dim)
173
+ A = A.transpose(1, 2).contiguous()
174
+ pass
175
+ attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
176
+ attn_output = self.apply_o(self, attn_output)
177
+ attn_weights = None
178
+ return attn_output, attn_weights, past_key_value
179
+ pass
180
+
181
+
182
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
183
+ def CohereDecoderLayer_fast_forward(
184
+ self,
185
+ hidden_states: torch.Tensor,
186
+ causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
187
+ attention_mask: Optional[torch.Tensor] = None,
188
+ position_ids: Optional[torch.LongTensor] = None,
189
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
190
+ output_attentions: Optional[bool] = False,
191
+ use_cache: Optional[bool] = False,
192
+ padding_mask: Optional[torch.LongTensor] = None,
193
+ *args, **kwargs,
194
+ ):
195
+ if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
196
+ out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
197
+
198
+ # Self Attention
199
+ residual = hidden_states
200
+ hidden_states = fast_layernorm_inference(self.input_layernorm, hidden_states, out_weight)
201
+ hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
202
+ hidden_states=hidden_states,
203
+ causal_mask=causal_mask,
204
+ attention_mask=attention_mask,
205
+ position_ids=position_ids,
206
+ past_key_value=past_key_value,
207
+ output_attentions=output_attentions,
208
+ use_cache=use_cache,
209
+ padding_mask=padding_mask,
210
+ )
211
+
212
+ # Fully Connected
213
+ hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states)
214
+ residual += hidden_states_attention
215
+ residual += hidden_states_mlp
216
+ hidden_states = residual
217
+ else:
218
+ residual = hidden_states
219
+ hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)
220
+ hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
221
+ hidden_states=hidden_states,
222
+ causal_mask=causal_mask,
223
+ attention_mask=attention_mask,
224
+ position_ids=position_ids,
225
+ past_key_value=past_key_value,
226
+ output_attentions=output_attentions,
227
+ use_cache=use_cache,
228
+ padding_mask=padding_mask,
229
+ )
230
+
231
+ # Fully Connected
232
+ hidden_states_mlp = self.mlp(hidden_states)
233
+ hidden_states = residual + hidden_states_attention + hidden_states_mlp
234
+ pass
235
+
236
+ outputs = (hidden_states,)
237
+ if output_attentions: outputs += (self_attn_weights,)
238
+ if use_cache: outputs += (present_key_value,)
239
+ return outputs
240
+ pass
241
+
242
+
243
+ from math import sqrt as math_sqrt
244
+ KV_CACHE_INCREMENT = 256 # KV Cache update size
245
+ torch_nn_functional_softmax = torch.nn.functional.softmax
246
+ torch_matmul = torch.matmul
247
+
248
+ def CohereAttention_fast_forward_inference(
249
+ self,
250
+ hidden_states: torch.Tensor,
251
+ past_key_value: Optional[Tuple[torch.Tensor]],
252
+ position_ids,
253
+ do_prefill = False,
254
+ attention_mask = None,
255
+ ):
256
+ Xn = hidden_states
257
+ bsz, _, hd = hidden_states.size()
258
+ K1, V1 = past_key_value
259
+ dtype = Xn.dtype
260
+
261
+ n_heads = self.num_heads
262
+ n_groups = self.num_key_value_groups
263
+ n_kv_heads = self.num_key_value_heads
264
+ head_dim = self.head_dim
265
+ attention_size = n_heads*head_dim
266
+ # assert(n_kv_heads * n_groups == n_heads)
267
+ seq_len = K1.shape[-2]
268
+ kv_seq_len = seq_len + 1
269
+
270
+ # Prefill phase
271
+ # if not hasattr(self, "paged_attention"):
272
+ if do_prefill:
273
+ self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
274
+ self.paged_attention_K = self.paged_attention[:,0]
275
+ self.paged_attention_V = self.paged_attention[:,1]
276
+ self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
277
+ self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
278
+ self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
279
+ self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
280
+ self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
281
+
282
+ # Mistral Nemo 12b has weird dimensions
283
+ if attention_size != self.hidden_size:
284
+ self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
285
+ else:
286
+ self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
287
+ pass
288
+
289
+ self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
290
+ self.scalar = 1.0 / math_sqrt(self.head_dim)
291
+ self.half_head_dim = head_dim // 2
292
+ # Cohere has QK layernorms
293
+ if self.use_qk_norm:
294
+ self.q_norm_out_weight = torch.empty(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
295
+ self.k_norm_out_weight = torch.empty(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
296
+ else:
297
+ self.q_norm_out_weight = None
298
+ self.k_norm_out_weight = None
299
+ pass
300
+ elif kv_seq_len >= self.paged_attention.shape[0]:
301
+ self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
302
+ self.paged_attention_K = self.paged_attention[:,0]
303
+ self.paged_attention_V = self.paged_attention[:,1]
304
+ self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
305
+ pass
306
+
307
+ Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
308
+ Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
309
+ Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
310
+ Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
311
+ Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
312
+ Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
313
+ if self.use_qk_norm:
314
+ Qn = fast_layernorm_inference(self.q_norm, Qn, self.q_norm_out_weight)
315
+ Kn = fast_layernorm_inference(self.k_norm, Kn, self.k_norm_out_weight)
316
+ pass
317
+
318
+ # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
319
+ # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
320
+ cos, sin = self.rotary_emb.get_cached(kv_seq_len)
321
+ cos = cos[position_ids].unsqueeze(1)
322
+ sin = sin[position_ids].unsqueeze(1)
323
+ h = self.half_head_dim
324
+
325
+ RH_Q = self.RH_Q
326
+ RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
327
+ RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
328
+ torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
329
+ Qn *= cos
330
+ Qn.addcmul_(RH_Q, sin)
331
+
332
+ RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
333
+ RH_K[:,:,:,:h] = Kn[:,:,:,h:]
334
+ RH_K[:,:,:,h:] = Kn[:,:,:,:h]
335
+ torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
336
+ Kn *= cos
337
+ Kn.addcmul_(RH_K, sin)
338
+
339
+ # New KV cache
340
+ # Kn = torch.cat([K1, Kn], dim = 2)
341
+ # Vn = torch.cat([V1, Vn], dim = 2)
342
+ self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
343
+ self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
344
+ Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
345
+ Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
346
+
347
+ # Handle sliding windows
348
+ sliding_window = getattr(self.config, "sliding_window", None)
349
+ if sliding_window is not None and kv_seq_len > sliding_window:
350
+ # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
351
+ slicing_tokens = 1 - sliding_window
352
+ Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
353
+ Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
354
+ else:
355
+ Knn, Vnn = Kn, Vn
356
+ pass
357
+
358
+ # Grouped query attention
359
+ _, _, cached_len, _ = Knn.shape
360
+ if n_groups != 1:
361
+ Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
362
+ Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
363
+ Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
364
+ Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
365
+ pass
366
+ # else:
367
+ # Knn, Vnn = Knn, Vnn
368
+ # pass
369
+
370
+ # Attention
371
+ if bsz == 1:
372
+ Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
373
+ # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
374
+ A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
375
+ # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
376
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
377
+ A = torch_matmul(A, Vnn, out = Qn)
378
+ else:
379
+ A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
380
+ pass
381
+ A = A.transpose(1, 2)
382
+ A = A.reshape(bsz, 1, attention_size)
383
+ A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
384
+ return A, (Kn, Vn)
385
+ pass
386
+
387
+
388
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
389
+ # @torch.inference_mode
390
+ def CohereModel_fast_forward_inference(
391
+ self,
392
+ input_ids,
393
+ past_key_values,
394
+ position_ids,
395
+ attention_mask = None,
396
+ ):
397
+ out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
398
+ input_ids = input_ids[:,:self.max_seq_length]
399
+ hidden_states = self.model.embed_tokens(input_ids)
400
+ hidden_states = hidden_states.to(self.config.torch_dtype)
401
+ bsz, q_len, hd = hidden_states.shape
402
+ seq_len = past_key_values[0][0].shape[-2]
403
+ if bsz != 1:
404
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
405
+ attention_mask,
406
+ (bsz, q_len),
407
+ hidden_states,
408
+ seq_len,
409
+ sliding_window = getattr(self.config, "sliding_window", None),
410
+ )
411
+ else:
412
+ attention_mask = None
413
+ pass
414
+
415
+ next_decoder_cache = []
416
+ for idx, decoder_layer in enumerate(self.model.layers):
417
+ residual = hidden_states
418
+ hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
419
+ hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
420
+ decoder_layer.self_attn,
421
+ hidden_states = hidden_states,
422
+ past_key_value = past_key_values[idx],
423
+ position_ids = position_ids,
424
+ attention_mask = attention_mask,
425
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
426
+ )
427
+
428
+ hidden_states_mlp = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
429
+ residual += hidden_states_attention
430
+ residual += hidden_states_mlp
431
+ hidden_states = residual
432
+
433
+ next_decoder_cache.append(present_key_value)
434
+ pass
435
+ hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight)
436
+
437
+ return BaseModelOutputWithPast(
438
+ last_hidden_state = hidden_states,
439
+ past_key_values = next_decoder_cache,
440
+ hidden_states = [],
441
+ attentions = [],
442
+ )
443
+ pass
444
+
445
+
446
+ class FastCohereModel(FastLlamaModel):
447
+
448
+ @staticmethod
449
+ def pre_patch():
450
+ init_name, function = patch_linear_scaling(
451
+ model_name = "cohere",
452
+ rope_module = LlamaRotaryEmbedding,
453
+ scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
454
+ attention_module = CohereAttention,
455
+ )
456
+ if init_name is not None:
457
+ exec(function, globals())
458
+ CohereAttention.__init__ = eval(init_name)
459
+ pass
460
+ CohereAttention .forward = CohereAttention_fast_forward
461
+ CohereSdpaAttention .forward = CohereAttention_fast_forward
462
+ CohereFlashAttention2.forward = CohereAttention_fast_forward
463
+ CohereDecoderLayer .forward = CohereDecoderLayer_fast_forward
464
+ CohereModel .forward = LlamaModel_fast_forward
465
+ CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
466
+ PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
467
+ fix_prepare_inputs_for_generation(CohereForCausalLM)
468
+
469
+ import transformers.models.cohere.modeling_cohere
470
+ transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
471
+ return
472
+ pass
473
+ pass
unsloth-main/unsloth-main/unsloth/models/dpo.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ __all__ = [
16
+ "PatchDPOTrainer",
17
+ ]
18
+
19
+ try:
20
+ from transformers.utils.notebook import (
21
+ IntervalStrategy,
22
+ NotebookTrainingTracker,
23
+ NotebookProgressCallback,
24
+ )
25
+ HAS_NOTEBOOK = True
26
+ except:
27
+ HAS_NOTEBOOK = False
28
+ pass
29
+ import torch
30
+ from ._utils import torch_compile_options
31
+ import inspect
32
+ import torch.nn as nn
33
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
34
+
35
+
36
+ DPOTrainer_metrics = [
37
+ "rewards/chosen",
38
+ "rewards/rejected",
39
+ "rewards/accuracies",
40
+ "rewards/margins",
41
+ "logps/rejected",
42
+ "logps/chosen",
43
+ "logits/rejected",
44
+ "logits/chosen",
45
+ ]
46
+ set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics)
47
+
48
+
49
+ def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs):
50
+ self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
51
+ self.training_loss = 0
52
+ self.last_log = 0
53
+ column_names = [self.first_column] + ["Training Loss"]
54
+ if args.eval_strategy != IntervalStrategy.NO:
55
+ column_names.append("Validation Loss")
56
+ column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics]
57
+ self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
58
+ pass
59
+
60
+
61
+ def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs):
62
+ # Only for when there is no evaluation
63
+ if args.eval_strategy == IntervalStrategy.NO and "loss" in logs:
64
+ values = {"Training Loss": logs["loss"]}
65
+ for metric in DPOTrainer_metrics:
66
+ values[metric.replace("/", " / ")] = logs[metric]
67
+ pass
68
+ # First column is necessarily Step since we're not in epoch eval strategy
69
+ values["Step"] = state.global_step
70
+ self.training_tracker.write_line(values)
71
+ pass
72
+ pass
73
+
74
+
75
+ def NotebookTrainingTracker_write_line(self, values):
76
+ """
77
+ Write the values in the inner table.
78
+
79
+ Args:
80
+ values (`Dict[str, float]`): The values to display.
81
+ """
82
+ if self.inner_table is None:
83
+ self.inner_table = [list(values.keys()), list(values.values())]
84
+ else:
85
+ columns = self.inner_table[0]
86
+ new_values = {}
87
+ for key, value in values.items():
88
+ lowered = key.lower()
89
+ if lowered in set_DPOTrainer_metrics:
90
+ new_values[lowered.replace("/", " / ")] = value
91
+ else:
92
+ new_values[key] = value
93
+ pass
94
+ values = new_values
95
+
96
+ self.inner_table[0] = columns
97
+ if len(self.inner_table) > 1:
98
+ last_values = self.inner_table[-1]
99
+ first_column = self.inner_table[0][0]
100
+ if last_values[0] != values[first_column]:
101
+ # write new line
102
+ self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
103
+ else:
104
+ # update last line
105
+ new_values = values
106
+ for c in columns:
107
+ if c not in new_values.keys():
108
+ new_values[c] = last_values[columns.index(c)]
109
+ self.inner_table[-1] = [new_values[c] for c in columns]
110
+ else:
111
+ # Edit for evaluation purposes
112
+ self.inner_table.append([values[c] if c in values else 0 for c in columns])
113
+ pass
114
+ pass
115
+ pass
116
+
117
+
118
+ def PatchDPOTrainer():
119
+ if HAS_NOTEBOOK:
120
+ from transformers.trainer import is_in_notebook
121
+ if is_in_notebook():
122
+ # Patch DPO notebook printing
123
+ NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line
124
+ from transformers.trainer import DEFAULT_PROGRESS_CALLBACK
125
+ DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin
126
+ DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log
127
+ pass
128
+ pass
129
+ pass
130
+
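+ # Illustrative sketch (not part of the upstream file): call the patch before building a
+ # TRL DPOTrainer inside a notebook so the progress table includes the metric columns above.
+ def _example_patch_dpo_usage():
+     PatchDPOTrainer()
+     # ... then construct the TRL DPOTrainer as usual; logging picks up the patched callbacks.
+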
unsloth-main/unsloth-main/unsloth/models/gemma.py ADDED
@@ -0,0 +1,430 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .llama import *
16
+ from ._utils import __version__
17
+ import math
18
+
19
+ try:
20
+ from transformers.models.gemma.modeling_gemma import (
21
+ GemmaAttention,
22
+ GemmaDecoderLayer,
23
+ GemmaModel,
24
+ GemmaForCausalLM,
25
+ GemmaRotaryEmbedding,
26
+ apply_rotary_pos_emb,
27
+ repeat_kv,
28
+ )
29
+ except:
30
+ from packaging.version import Version
31
+ transformers_version = Version(transformers_version)
32
+ if not transformers_version >= Version("4.38"):
33
+ raise ImportError(
34
+ f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
35
+ f"The minimum required version is 4.38.\n"\
36
+ f'Try `pip install --upgrade "transformers>=4.38"`\n'\
37
+ f"to obtain the latest transformers build, then restart this session."\
38
+ )
39
+ pass
40
+ pass
41
+
42
+ from transformers.modeling_attn_mask_utils import (
43
+ _prepare_4d_causal_attention_mask_for_sdpa,
44
+ )
45
+ # For Pytorch 2.1.1
46
+ try:
47
+ from transformers.models.gemma.modeling_gemma import (
48
+ GemmaSdpaAttention,
49
+ GemmaFlashAttention2,
50
+ )
51
+ except:
52
+ GemmaSdpaAttention = GemmaAttention
53
+ GemmaFlashAttention2 = GemmaAttention
54
+ pass
55
+
56
+
57
+ torch_nn_functional_gelu = torch.nn.functional.gelu
58
+ def fast_geglu_inference(self, X):
59
+ # gate = self.gate_proj(X)
60
+ # up = self.up_proj(X)
61
+ bsz, _, hd = X.shape
62
+ # mlp_size = self.config.intermediate_size
63
+ # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
64
+
65
+ gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
66
+ up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
67
+ gate = torch_nn_functional_gelu(gate, approximate = "tanh")
68
+ gate *= up
69
+
70
+ # X = self.down_proj(gate)
71
+ down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
72
+ return down
73
+ pass
74
+
75
+
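+ # Illustrative sketch (not part of the upstream file): assuming a standard Gemma MLP with
+ # gate/up/down projections, the fused inference path above computes
+ # down_proj( gelu_tanh(gate_proj(X)) * up_proj(X) ).
+ def _example_geglu_reference(mlp, X):
+     gate = mlp.gate_proj(X)
+     up   = mlp.up_proj(X)
+     return mlp.down_proj(torch_nn_functional_gelu(gate, approximate = "tanh") * up)
+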
76
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
77
+ def GemmaDecoderLayer_fast_forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
81
+ attention_mask: Optional[torch.Tensor] = None,
82
+ position_ids: Optional[torch.LongTensor] = None,
83
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
84
+ output_attentions: Optional[bool] = False,
85
+ use_cache: Optional[bool] = False,
86
+ padding_mask: Optional[torch.LongTensor] = None,
87
+ *args, **kwargs,
88
+ ):
89
+ if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
90
+ out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
91
+
92
+ # Self Attention
93
+ residual = hidden_states
94
+ hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
95
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
96
+ hidden_states=hidden_states,
97
+ causal_mask=causal_mask,
98
+ attention_mask=attention_mask,
99
+ position_ids=position_ids,
100
+ past_key_value=past_key_value,
101
+ output_attentions=output_attentions,
102
+ use_cache=use_cache,
103
+ padding_mask=padding_mask,
104
+ )
105
+ hidden_states += residual
106
+
107
+ # Fully Connected
108
+ residual = hidden_states
109
+ hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
110
+ hidden_states = fast_geglu_inference(self.mlp, hidden_states)
111
+ hidden_states += residual
112
+ else:
113
+ residual = hidden_states
114
+ hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
115
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
116
+ hidden_states=hidden_states,
117
+ causal_mask=causal_mask,
118
+ attention_mask=attention_mask,
119
+ position_ids=position_ids,
120
+ past_key_value=past_key_value,
121
+ output_attentions=output_attentions,
122
+ use_cache=use_cache,
123
+ padding_mask=padding_mask,
124
+ )
125
+ hidden_states = residual + hidden_states
126
+
127
+ # Fully Connected
128
+ residual = hidden_states
129
+ hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
130
+ hidden_states = self.mlp(hidden_states)
131
+ hidden_states = residual + hidden_states
132
+ pass
133
+
134
+ outputs = (hidden_states,)
135
+ if output_attentions: outputs += (self_attn_weights,)
136
+ if use_cache: outputs += (present_key_value,)
137
+ return outputs
138
+ pass
139
+
140
+
141
+ from math import sqrt as math_sqrt
142
+
143
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
144
+ # @torch.inference_mode
145
+ def GemmaModel_fast_forward_inference(
146
+ self,
147
+ input_ids,
148
+ past_key_values,
149
+ position_ids,
150
+ attention_mask = None,
151
+ ):
152
+ out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
153
+ input_ids = input_ids[:,:self.max_seq_length]
154
+ hidden_states = self.model.embed_tokens(input_ids)
155
+ hidden_states = hidden_states.to(self.config.torch_dtype)
156
+ # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
157
+ # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
158
+ hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
159
+
160
+ bsz, q_len, hd = hidden_states.shape
161
+ seq_len = past_key_values[0][0].shape[-2]
162
+ if bsz != 1:
163
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
164
+ attention_mask,
165
+ (bsz, q_len),
166
+ hidden_states,
167
+ seq_len,
168
+ )
169
+ pass
170
+
171
+ next_decoder_cache = []
172
+ for idx, decoder_layer in enumerate(self.model.layers):
173
+ residual = hidden_states
174
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
175
+ hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
176
+ decoder_layer.self_attn,
177
+ hidden_states = hidden_states,
178
+ past_key_value = past_key_values[idx],
179
+ position_ids = position_ids,
180
+ attention_mask = attention_mask,
181
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
182
+ )
183
+ hidden_states += residual
184
+
185
+ residual = hidden_states
186
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
187
+ hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
188
+ hidden_states += residual
189
+
190
+ next_decoder_cache.append(present_key_value)
191
+ pass
192
+ hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
193
+
194
+ return BaseModelOutputWithPast(
195
+ last_hidden_state = hidden_states,
196
+ past_key_values = next_decoder_cache,
197
+ hidden_states = [],
198
+ attentions = [],
199
+ )
200
+ pass
201
+
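The precision comments above can be checked directly: materialising sqrt(hidden_size) in bfloat16 rounds it to the nearest representable value, which is what the in-place multiply uses when hidden_states is bfloat16. A quick illustration:

import math, torch

scale = math.sqrt(3072)                            # 55.42562...
print(torch.tensor(scale, dtype = torch.float32))  # tensor(55.4256)
print(torch.tensor(scale, dtype = torch.bfloat16)) # tensor(55.5000, dtype=torch.bfloat16)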
202
+
203
+ # Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
204
+ # Formulates cos and sin differently from Llama!
205
+ class GemmaFixedRotaryEmbedding(torch.nn.Module):
206
+ # Fixes https://github.com/huggingface/transformers/pull/28837
207
+ # https://github.com/microsoft/DeepSpeed/issues/4932
208
+ # The precision of RoPE buffers is not correct, so we cast to int64.
209
+ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
210
+ config = None, # [TODO] Hack to pass in config - need to remove later
211
+ ):
212
+ super().__init__()
213
+ if config is not None: return # [TODO] Hack to pass in config - need to remove later
214
+ self.dim = dim
215
+ self.max_position_embeddings = max_position_embeddings
216
+ self.base = base
217
+ # For dynamic RoPE, we first set it to a max of 4 * 8192 tokens, then iteratively grow this
218
+ self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
219
+
220
+ # Build here to make `torch.jit.trace` work.
221
+ self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
222
+ pass
223
+
224
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
225
+ # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
226
+ # in FP32. They are applied (multiplied) in FP32 as well.
227
+ self.current_rope_size = seq_len
228
+
229
+ # The difference is that we do the division explicitly instead of t * (1/x), i.e. we do t/x.
230
+ freq_exponents = (2.0 / self.dim) * (
231
+ torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
232
+ )
233
+ timescale = self.base**freq_exponents
234
+ positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
235
+ radians_new = positions[..., None] / timescale[None, None, :]
236
+ radians_new = radians_new.squeeze(0)
237
+
238
+ emb = torch.cat((radians_new, radians_new), dim = -1)
239
+ # We must do RoPE in float32!
240
+ cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
241
+ sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
242
+ self.register_buffer("cos_cached", cos, persistent = False)
243
+ self.register_buffer("sin_cached", sin, persistent = False)
244
+ pass
245
+
246
+ def forward(self, x, position_ids=None, seq_len=None):
247
+ # x: [bs, num_attention_heads, seq_len, head_size]
248
+ if seq_len > self.current_rope_size:
249
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
250
+
251
+ return (
252
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
253
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
254
+ )
255
+ pass
256
+
257
+ def get_cached(self, seq_len = None):
258
+ return self.cos_cached, self.sin_cached
259
+ pass
260
+
261
+ def extend_rope_embedding(self, x, seq_len):
262
+ if seq_len <= self.current_rope_size: return
263
+ # Iteratively grow by increments of 8192
264
+ self.current_rope_size = math.ceil(seq_len / 8192) * 8192
265
+ self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
266
+ pass
267
+ pass
268
+
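The cache construction above follows Gemma's positional-embedding formulation (positions divided by a per-dimension timescale) rather than Llama's t * inv_freq product. Stripped of buffers and device placement, a sketch of the same computation:

import torch

def gemma_rope_cache(dim, seq_len, base = 10000.0):
    # timescale_i = base ** (2 i / dim); radians = position / timescale, kept in float32
    freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype = torch.int64).float()
    timescale = base ** freq_exponents
    positions = torch.arange(seq_len, dtype = torch.int64).float()
    radians = positions[:, None] / timescale[None, :]
    emb = torch.cat((radians, radians), dim = -1)
    return emb.cos(), emb.sin()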
269
+
270
+ class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
271
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
272
+ # Fixes https://github.com/huggingface/transformers/pull/28837
273
+ # https://github.com/microsoft/DeepSpeed/issues/4932
274
+ # The precision of RoPE buffers is not correct, so we cast to int64.
275
+ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
276
+ config = None, # [TODO] Hack to pass in config - need to remove later
277
+ ):
278
+ self.scaling_factor = scaling_factor
279
+ super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
280
+ pass
281
+
282
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
283
+ # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
284
+ # in FP32. They are applied (multiplied) in FP32 as well.
285
+ self.current_rope_size = seq_len
286
+
287
+ # The difference is that we do the division explicitly instead of t * (1/x), i.e. we do t/x.
288
+ freq_exponents = (2.0 / self.dim) * (
289
+ torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
290
+ )
291
+ timescale = self.base**freq_exponents
292
+ positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
293
+ positions = positions / self.scaling_factor
294
+ radians_new = positions[..., None] / timescale[None, None, :]
295
+ radians_new = radians_new.squeeze(0)
296
+
297
+ emb = torch.cat((radians_new, radians_new), dim = -1)
298
+ # We must do RoPE in float32!
299
+ cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
300
+ sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
301
+ self.register_buffer("cos_cached", cos, persistent = False)
302
+ self.register_buffer("sin_cached", sin, persistent = False)
303
+ pass
304
+ pass
305
+
306
+
307
+ class FastGemmaModel(FastLlamaModel):
308
+
309
+ @staticmethod
310
+ def pre_patch():
311
+ init_name, function = patch_linear_scaling(
312
+ model_name = "gemma",
313
+ rope_module = GemmaFixedRotaryEmbedding,
314
+ scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
315
+ attention_module = GemmaAttention,
316
+ )
317
+ if init_name is not None:
318
+ exec(function, globals())
319
+ GemmaAttention.__init__ = eval(init_name)
320
+ pass
321
+ GemmaAttention .forward = LlamaAttention_fast_forward
322
+ GemmaSdpaAttention .forward = LlamaAttention_fast_forward
323
+ GemmaFlashAttention2.forward = LlamaAttention_fast_forward
324
+ GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
325
+ GemmaModel .forward = LlamaModel_fast_forward
326
+ GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
327
+ PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
328
+ fix_prepare_inputs_for_generation(GemmaForCausalLM)
329
+
330
+ # Solves https://github.com/unslothai/unsloth/issues/168
331
+ # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
332
+ # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
333
+ # https://github.com/huggingface/transformers/pull/27931
334
+ # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
335
+ import transformers.models.gemma.modeling_gemma
336
+ transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
337
+ return
338
+ pass
339
+
340
+
341
+ @staticmethod
342
+ def post_patch(model):
343
+ # Patch model for Gemma
344
+ layers = model.model.layers
345
+
346
+ # Torch.compile fails on embedding matrix??
347
+ # Workaround randomly fixes it for torch versions < 2.2
348
+ model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
349
+ model.config.update({"unsloth_version" : __version__})
350
+
351
+ # We also do this for the lm_head
352
+ lm_head = torch.nn.Linear(1, 1, bias = None)
353
+ del lm_head.weight
354
+ lm_head.weight = model.lm_head.weight
355
+ lm_head.in_features = lm_head.weight.shape[1]
356
+ lm_head.out_features = lm_head.weight.shape[0]
357
+ model.lm_head = lm_head
358
+
359
+ # Gemma has tied weights! This means lm_head == embed_tokens
360
+ if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
361
+ lm_head = torch.nn.Linear(1, 1, bias = None)
362
+ del lm_head.weight
363
+ lm_head.weight = model.model.embed_tokens.weight
364
+ lm_head.in_features = lm_head.weight.shape[1]
365
+ lm_head.out_features = lm_head.weight.shape[0]
366
+ model.lm_head = lm_head
367
+ pass
368
+
369
+ # Also patch all dtypes - BnB seems to not allocate the correct type?
370
+ # BnB default dtype seems to be float16!
371
+ correct_dtype = lm_head.weight.dtype
372
+
373
+ for name, module in model.named_modules():
374
+ if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
375
+ weight = module.weight
376
+ quant_state = weight.quant_state
377
+
378
+ if type(quant_state) is list:
379
+ # BnB seems to have float16 as default!
380
+ module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
381
+ else:
382
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
383
+ quant_state.dtype = correct_dtype
384
+ pass
385
+ pass
386
+ # Downcast RoPE embedding to correct data type
387
+ # RoPE must be done in float32 for Gemma
388
+ # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
389
+ # and (module.cos_cached.dtype != correct_dtype):
390
+
391
+ # module.cos_cached = module.cos_cached.to(correct_dtype)
392
+ # module.sin_cached = module.sin_cached.to(correct_dtype)
393
+ # pass
394
+ # pass
395
+ pass
396
+
397
+ # Add 1 to weight
398
+ # return output * (1 + self.weight)
399
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
400
+ from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
401
+
402
+ # Freeze all parameters except LoRA
403
+ # We do this first since += 1 seems to not be liked by requires_grad = True
404
+ for name, param in model.named_parameters():
405
+ if ".lora_A." in name or ".lora_B." in name:
406
+ param.requires_grad_(True)
407
+ else:
408
+ param.requires_grad_(False)
409
+ pass
410
+
411
+ # Patch RMS Layernorm
412
+ for name, module in model.named_modules():
413
+ if isinstance(module, GemmaRMSNorm):
414
+ # Must be in float32
415
+ # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
416
+ # module = module.to(torch.float32)
417
+ # Leave + 1 to Triton kernel itself
418
+ # module.weight += 1.0 # return output * (1 + self.weight)
419
+ if not hasattr(module, "variance_epsilon"):
420
+ module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
421
+ pass
422
+
423
+ # Clear deleted GPU items
424
+ import gc
425
+ for _ in range(3):
426
+ gc.collect()
427
+ torch.cuda.empty_cache()
428
+ return model
429
+ pass
430
+ pass
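A note on the tied-weights branch in post_patch above: Gemma ties lm_head to embed_tokens, and whether the two already share storage can be checked by comparing data pointers, which is the condition used there. For illustration only:

# True when lm_head and embed_tokens share the same underlying storage (tied weights)
tied = model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr()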
unsloth-main/unsloth-main/unsloth/models/gemma2.py ADDED
@@ -0,0 +1,581 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .llama import *
16
+ from ._utils import __version__
17
+ from .gemma import (
18
+ GemmaFixedRotaryEmbedding,
19
+ GemmaFixedLinearScalingRotaryEmbedding,
20
+ fast_geglu_inference,
21
+ )
22
+ try:
23
+ from transformers.models.gemma2.modeling_gemma2 import (
24
+ Gemma2Attention,
25
+ Gemma2DecoderLayer,
26
+ Gemma2Model,
27
+ Gemma2ForCausalLM,
28
+ Gemma2RotaryEmbedding,
29
+ apply_rotary_pos_emb,
30
+ repeat_kv,
31
+ )
32
+ except:
33
+ from packaging.version import Version
34
+ transformers_version = Version(transformers_version)
35
+ if not transformers_version >= Version("4.42"):
36
+ raise ImportError(
37
+ f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
38
+ f"The minimum required version is 4.42.3.\n"\
39
+ f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
40
+ f"to obtain the latest transformers build, then restart this session."\
41
+ )
42
+ pass
43
+ pass
44
+
45
+ from transformers.modeling_attn_mask_utils import (
46
+ _prepare_4d_causal_attention_mask_for_sdpa,
47
+ )
48
+ # For PyTorch 2.1.1
49
+ try:
50
+ from transformers.models.gemma2.modeling_gemma2 import (
51
+ Gemma2SdpaAttention,
52
+ Gemma2FlashAttention2,
53
+ )
54
+ except:
55
+ Gemma2SdpaAttention = Gemma2Attention
56
+ Gemma2FlashAttention2 = Gemma2Attention
57
+ pass
58
+
59
+ if HAS_FLASH_ATTENTION_SOFTCAPPING:
60
+ from flash_attn import flash_attn_func
61
+
62
+ # [TODO] We must randomly use torch.compile?
63
+ # I checked the gradients and formulas and I'm sure it's correct.
64
+ # I'm stumped :(
65
+ @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
66
+ def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
67
+ old_dtype = X.dtype
68
+ X = X.float()
69
+ X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
70
+ (1.0 + layernorm.weight.float())
71
+ return X.to(old_dtype)
72
+ pass
73
+
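The compiled layernorm above implements Gemma's RMSNorm variant, which normalises in float32 and scales by (1 + weight) rather than weight. An uncompiled reference restatement:

import torch

def gemma_rmsnorm_reference(weight, X, eps):
    X32 = X.float()
    normed = X32 * torch.rsqrt(X32.square().mean(-1, keepdim = True) + eps)
    return (normed * (1.0 + weight.float())).to(X.dtype)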
74
+
75
+ # Logit softcapping
76
+ def Gemma2Attention_fast_forward(
77
+ self,
78
+ hidden_states: torch.Tensor,
79
+ causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
80
+ attention_mask: Optional[torch.Tensor] = None,
81
+ position_ids: Optional[torch.LongTensor] = None,
82
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
83
+ output_attentions: bool = False,
84
+ use_cache: bool = False,
85
+ padding_mask: Optional[torch.LongTensor] = None,
86
+ *args, **kwargs,
87
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
88
+
89
+ # Clear inference
90
+ if hasattr(self, "paged_attention"):
91
+ del self.paged_attention_K
92
+ del self.paged_attention_V
93
+ del self.paged_attention
94
+ del self.temp_QA
95
+ del self.temp_KV
96
+ del self.RH_Q
97
+ del self.attention
98
+ pass
99
+
100
+ bsz, q_len, _ = hidden_states.size()
101
+
102
+ n_heads = self.num_heads
103
+ n_groups = self.num_key_value_groups
104
+ n_kv_heads = self.num_key_value_heads
105
+ head_dim = self.head_dim
106
+ assert(n_kv_heads * n_groups == n_heads)
107
+
108
+ Q, K, V = self.apply_qkv(self, hidden_states)
109
+ Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
110
+ K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
111
+ V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
112
+
113
+ kv_seq_len = K.shape[-2]
114
+ if past_key_value is not None:
115
+ kv_seq_len += past_key_value[0].shape[-2]
116
+
117
+ if position_ids is None:
118
+ cos = self.rotary_emb.cos_cached
119
+ sin = self.rotary_emb.sin_cached
120
+ Q, K = fast_rope_embedding(Q, K, cos, sin)
121
+ else:
122
+ cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
123
+ Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
124
+ pass
125
+
126
+ if past_key_value is not None:
127
+ K = torch.cat([past_key_value[0], K], dim = 2)
128
+ V = torch.cat([past_key_value[1], V], dim = 2)
129
+ pass
130
+ past_key_value = (K, V) if use_cache else None
131
+
132
+ # Only enable if the attention_mask is True
133
+ has_sliding_window = type(causal_mask) is bool and causal_mask is True
134
+ if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
135
+ window = (-1, -1)
136
+ if has_sliding_window:
137
+ sw = getattr(self.config, "sliding_window", None)
138
+ sw = kv_seq_len if (sw is None or sw == "null") else sw
139
+ window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
140
+ pass
141
+
142
+ # FA uses 1 / sqrt for softmax_scale!
143
+ if not hasattr(self, "_flash_attention_softmax_scale"):
144
+ self._flash_attention_softmax_scale = 1.0 / (self.config.query_pre_attn_scalar**0.5)
145
+ pass
146
+
147
+ Q = Q.transpose(1, 2)
148
+ K = K.transpose(1, 2)
149
+ V = V.transpose(1, 2)
150
+ A = flash_attn_func(
151
+ Q, K, V,
152
+ causal = True,
153
+ softcap = self.config.attn_logit_softcapping,
154
+ softmax_scale = self._flash_attention_softmax_scale,
155
+ window_size = window,
156
+ )
157
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
158
+ else:
159
+ fx = slow_inference_attention_softcapping \
160
+ if "_flag_for_generation" in kwargs else \
161
+ slow_attention_softcapping
162
+ A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len)
163
+ pass
164
+ A = self.apply_o(self, A)
165
+ return A, None, past_key_value
166
+ pass
167
+
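Logit softcapping bounds the raw attention scores to (-t, t) before the softmax; both the flash-attention path (via softcap=) and the manual A *= 1/t, tanh, A *= t sequence used during inference amount to the transform below:

import torch

def softcap_logits(scores, t):
    # t = config.attn_logit_softcapping; output lies strictly inside (-t, t)
    return t * torch.tanh(scores / t)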
168
+
169
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
170
+ def Gemma2DecoderLayer_fast_forward(
171
+ self,
172
+ hidden_states: torch.Tensor,
173
+ causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
174
+ attention_mask: Optional[torch.Tensor] = None,
175
+ position_ids: Optional[torch.LongTensor] = None,
176
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
177
+ output_attentions: Optional[bool] = False,
178
+ use_cache: Optional[bool] = False,
179
+ padding_mask: Optional[torch.LongTensor] = None,
180
+ *args, **kwargs,
181
+ ):
182
+ if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
183
+ out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
184
+
185
+ # Self Attention
186
+ residual = hidden_states
187
+ hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
188
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
189
+ hidden_states=hidden_states,
190
+ causal_mask=causal_mask,
191
+ attention_mask=attention_mask,
192
+ position_ids=position_ids,
193
+ past_key_value=past_key_value,
194
+ output_attentions=output_attentions,
195
+ use_cache=use_cache,
196
+ padding_mask=padding_mask,
197
+ _flag_for_generation=True,
198
+ )
199
+ hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
200
+ hidden_states += residual
201
+
202
+ # Fully Connected
203
+ residual = hidden_states
204
+ hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
205
+ hidden_states = fast_geglu_inference(self.mlp, hidden_states)
206
+ hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
207
+ hidden_states += residual
208
+ else:
209
+ residual = hidden_states
210
+ hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
211
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
212
+ hidden_states=hidden_states,
213
+ causal_mask=causal_mask,
214
+ attention_mask=attention_mask,
215
+ position_ids=position_ids,
216
+ past_key_value=past_key_value,
217
+ output_attentions=output_attentions,
218
+ use_cache=use_cache,
219
+ padding_mask=padding_mask,
220
+ )
221
+ hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
222
+ hidden_states = residual + hidden_states
223
+
224
+ # Fully Connected
225
+ residual = hidden_states
226
+ hidden_states = fast_rms_layernorm_gemma2_compiled(self. pre_feedforward_layernorm, hidden_states, gemma = True)
227
+ hidden_states = self.mlp(hidden_states)
228
+ hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
229
+ hidden_states = residual + hidden_states
230
+ pass
231
+
232
+ outputs = (hidden_states,)
233
+ if output_attentions: outputs += (self_attn_weights,)
234
+ if use_cache: outputs += (present_key_value,)
235
+ return outputs
236
+ pass
237
+
238
+
239
+ from math import sqrt as math_sqrt
240
+ KV_CACHE_INCREMENT = 256 # KV Cache update size
241
+ torch_nn_functional_softmax = torch.nn.functional.softmax
242
+ torch_matmul = torch.matmul
243
+ torch_tanh = torch.tanh
244
+
245
+ def Gemma2Attention_fast_forward_inference(
246
+ self,
247
+ hidden_states: torch.Tensor,
248
+ past_key_value: Optional[Tuple[torch.Tensor]],
249
+ position_ids,
250
+ do_prefill = False,
251
+ attention_mask = None,
252
+ use_sliding_window = False,
253
+ ):
254
+ Xn = hidden_states
255
+ bsz, _, hd = hidden_states.size()
256
+ K1, V1 = past_key_value
257
+ dtype = Xn.dtype
258
+
259
+ n_heads = self.num_heads
260
+ n_groups = self.num_key_value_groups
261
+ n_kv_heads = self.num_key_value_heads
262
+ head_dim = self.head_dim
263
+ attention_size = n_heads*head_dim
264
+ # assert(n_kv_heads * n_groups == n_heads)
265
+ seq_len = K1.shape[-2]
266
+ kv_seq_len = seq_len + 1
267
+
268
+ # Prefill phase
269
+ # if not hasattr(self, "paged_attention"):
270
+ if do_prefill:
271
+ self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
272
+ self.paged_attention_K = self.paged_attention[:,0]
273
+ self.paged_attention_V = self.paged_attention[:,1]
274
+ self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
275
+ self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
276
+ self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
277
+ self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
278
+ self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
279
+ # Only for Gemma2
280
+ self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
281
+ self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
282
+
283
+ # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
284
+ # Gemma 9b should use 256 and not 224 (hidden_size / num_attention_heads). 27b uses the below
285
+ # We default to using the config file itself
286
+ # s = self.config.hidden_size // self.config.num_attention_heads
287
+ self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)
288
+ # self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
289
+ self.half_head_dim = head_dim // 2
290
+ self. t = self.config.attn_logit_softcapping
291
+ self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
292
+ elif kv_seq_len >= self.paged_attention.shape[0]:
293
+ self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
294
+ self.paged_attention_K = self.paged_attention[:,0]
295
+ self.paged_attention_V = self.paged_attention[:,1]
296
+ self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
297
+ pass
298
+
299
+ Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
300
+ Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
301
+ Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
302
+ Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
303
+ Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
304
+ Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
305
+
306
+ # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
307
+ # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
308
+ cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
309
+ sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
310
+ h = self.half_head_dim
311
+
312
+ RH_Q = self.RH_Q
313
+ RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
314
+ RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
315
+ torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
316
+ Qn *= cos
317
+ Qn.addcmul_(RH_Q, sin)
318
+
319
+ RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
320
+ RH_K[:,:,:,:h] = Kn[:,:,:,h:]
321
+ RH_K[:,:,:,h:] = Kn[:,:,:,:h]
322
+ torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
323
+ Kn *= cos
324
+ Kn.addcmul_(RH_K, sin)
325
+
326
+ # New KV cache
327
+ # Kn = torch.cat([K1, Kn], dim = 2)
328
+ # Vn = torch.cat([V1, Vn], dim = 2)
329
+ self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
330
+ self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
331
+ Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
332
+ Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
333
+
334
+ # Handle sliding windows
335
+ sliding_window = self.config.sliding_window
336
+ if use_sliding_window and kv_seq_len > sliding_window:
337
+ # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
338
+ slicing_tokens = 1 - sliding_window
339
+ Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
340
+ Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
341
+ else:
342
+ Knn, Vnn = Kn, Vn
343
+ pass
344
+
345
+ # Grouped query attention
346
+ _, _, cached_len, _ = Knn.shape
347
+ if n_groups != 1:
348
+ Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
349
+ Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
350
+ Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
351
+ Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
352
+ pass
353
+ # else:
354
+ # Knn, Vnn = Knn, Vnn
355
+ # pass
356
+
357
+ # Attention
358
+ # if bsz == 1:
359
+ Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
360
+ # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
361
+ A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
362
+ # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
363
+
364
+ A *= self.reciprocal_t; torch_tanh(A, out = A); A *= self.t; # Logit softcapping
365
+
366
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
367
+ A = torch_matmul(A, Vnn, out = Qn)
368
+ # else:
369
+ # A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
370
+ # pass
371
+ A = A.transpose(1, 2)
372
+ A = A.reshape(bsz, 1, attention_size)
373
+ A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
374
+ return A, (Kn, Vn)
375
+ pass
376
+
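The in-place RoPE applied to Qn and Kn above (fill RH with the swapped halves, negate the first half, then the cos/sin multiply-add) is the standard rotate-half formulation; a straightforward reference version:

import torch

def rope_rotate_half_reference(q, cos, sin):
    # q * cos + rotate_half(q) * sin, where rotate_half([q1, q2]) = [-q2, q1]
    h = q.shape[-1] // 2
    rotated = torch.cat((-q[..., h:], q[..., :h]), dim = -1)
    return q * cos + rotated * sin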
377
+
378
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
379
+ # @torch.inference_mode
380
+ def Gemma2Model_fast_forward_inference(
381
+ self,
382
+ input_ids,
383
+ past_key_values,
384
+ position_ids,
385
+ attention_mask = None,
386
+ ):
387
+ out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
388
+ input_ids = input_ids[:,:self.max_seq_length]
389
+ hidden_states = self.model.embed_tokens(input_ids)
390
+ hidden_states = hidden_states.to(self.config.torch_dtype)
391
+ # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
392
+ # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
393
+ hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
394
+
395
+ bsz, q_len, hd = hidden_states.shape
396
+ seq_len = past_key_values[0][0].shape[-2]
397
+ if bsz != 1:
398
+ if HAS_FLASH_ATTENTION_SOFTCAPPING:
399
+ SWA = True
400
+ GA = False
401
+ else:
402
+ SWA = _prepare_4d_causal_attention_mask_for_sdpa(
403
+ attention_mask,
404
+ (bsz, q_len),
405
+ hidden_states,
406
+ seq_len,
407
+ sliding_window = self.config.sliding_window,
408
+ )
409
+ GA = _prepare_4d_causal_attention_mask_for_sdpa(
410
+ attention_mask,
411
+ (bsz, q_len),
412
+ hidden_states,
413
+ seq_len,
414
+ )
415
+ pass
416
+ else:
417
+ SWA = attention_mask
418
+ GA = attention_mask
419
+ pass
420
+ next_decoder_cache = []
421
+ for idx, decoder_layer in enumerate(self.model.layers):
422
+
423
+ use_sliding_window = idx % 2 == 0
424
+
425
+ residual = hidden_states
426
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
427
+ hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
428
+ decoder_layer.self_attn,
429
+ hidden_states = hidden_states,
430
+ past_key_value = past_key_values[idx],
431
+ position_ids = position_ids,
432
+ attention_mask = SWA if use_sliding_window else GA,
433
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
434
+ use_sliding_window = use_sliding_window,
435
+ )
436
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
437
+ hidden_states += residual
438
+
439
+ residual = hidden_states
440
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
441
+ hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
442
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
443
+ hidden_states += residual
444
+
445
+ next_decoder_cache.append(present_key_value)
446
+ pass
447
+ hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
448
+
449
+ return BaseModelOutputWithPast(
450
+ last_hidden_state = hidden_states,
451
+ past_key_values = next_decoder_cache,
452
+ hidden_states = [],
453
+ attentions = [],
454
+ )
455
+ pass
456
+
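Gemma-2 alternates local (sliding-window) and global attention across layers; the loop above hands even-indexed layers the sliding-window mask and the rest the global mask. A hypothetical helper mirroring that choice:

def mask_for_layer(idx, sliding_window_mask, global_mask):
    # Gemma-2 convention used above: even-indexed layers attend with the sliding window
    return sliding_window_mask if idx % 2 == 0 else global_mask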
457
+
458
+ class FastGemma2Model(FastLlamaModel):
459
+
460
+ @staticmethod
461
+ def pre_patch():
462
+ init_name, function = patch_linear_scaling(
463
+ model_name = "gemma2",
464
+ rope_module = GemmaFixedRotaryEmbedding,
465
+ scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
466
+ attention_module = Gemma2Attention,
467
+ )
468
+ if init_name is not None:
469
+ exec(function, globals())
470
+ Gemma2Attention.__init__ = eval(init_name)
471
+ pass
472
+ Gemma2Attention .forward = Gemma2Attention_fast_forward
473
+ Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward
474
+ Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
475
+ Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward
476
+ Gemma2Model .forward = LlamaModel_fast_forward
477
+ Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
478
+ PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
479
+ fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
480
+
481
+ # Solves https://github.com/unslothai/unsloth/issues/168
482
+ # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
483
+ # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
484
+ # https://github.com/huggingface/transformers/pull/27931
485
+ # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
486
+ import transformers.models.gemma2.modeling_gemma2
487
+ transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
488
+ return
489
+ pass
490
+
491
+
492
+ @staticmethod
493
+ def post_patch(model):
494
+ # Patch model for Gemma
495
+ layers = model.model.layers
496
+
497
+ # Torch.compile fails on embedding matrix??
498
+ # Workaround randomly fixes it for torch versions < 2.2
499
+ model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
500
+ model.config.update({"unsloth_version" : __version__})
501
+
502
+ # We also do this for the lm_head
503
+ lm_head = torch.nn.Linear(1, 1, bias = None)
504
+ del lm_head.weight
505
+ lm_head.weight = model.lm_head.weight
506
+ lm_head.in_features = lm_head.weight.shape[1]
507
+ lm_head.out_features = lm_head.weight.shape[0]
508
+ model.lm_head = lm_head
509
+
510
+ # Gemma has tied weights! This means lm_head == embed_tokens
511
+ if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
512
+ lm_head = torch.nn.Linear(1, 1, bias = None)
513
+ del lm_head.weight
514
+ lm_head.weight = model.model.embed_tokens.weight
515
+ lm_head.in_features = lm_head.weight.shape[1]
516
+ lm_head.out_features = lm_head.weight.shape[0]
517
+ model.lm_head = lm_head
518
+ pass
519
+
520
+ # Also patch all dtypes - BnB seems to not allocate the correct type?
521
+ # BnB default dtype seems to be float16!
522
+ correct_dtype = lm_head.weight.dtype
523
+
524
+ for name, module in model.named_modules():
525
+ if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
526
+ weight = module.weight
527
+ quant_state = weight.quant_state
528
+
529
+ if type(quant_state) is list:
530
+ # BnB seems to have float16 as default!
531
+ module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
532
+ else:
533
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
534
+ quant_state.dtype = correct_dtype
535
+ pass
536
+ pass
537
+ # Downcast RoPE embedding to correct data type
538
+ # RoPE must be done in float32 for Gemma
539
+ # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
540
+ # and (module.cos_cached.dtype != correct_dtype):
541
+
542
+ # module.cos_cached = module.cos_cached.to(correct_dtype)
543
+ # module.sin_cached = module.sin_cached.to(correct_dtype)
544
+ # pass
545
+ # pass
546
+ pass
547
+
548
+ # Add 1 to weight
549
+ # return output * (1 + self.weight)
550
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
551
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
552
+
553
+ # Freeze all parameters except LoRA
554
+ # We do this first since += 1 seems to not be liked by requires_grad = True
555
+ for name, param in model.named_parameters():
556
+ if ".lora_A." in name or ".lora_B." in name:
557
+ param.requires_grad_(True)
558
+ else:
559
+ param.requires_grad_(False)
560
+ pass
561
+
562
+ # Patch RMS Layernorm
563
+ for name, module in model.named_modules():
564
+ if isinstance(module, Gemma2RMSNorm):
565
+ # Must be in float32
566
+ # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
567
+ # module = module.to(torch.float32)
568
+ # Leave + 1 to Triton kernel itself
569
+ # module.weight += 1.0 # return output * (1 + self.weight)
570
+ if not hasattr(module, "variance_epsilon"):
571
+ module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
572
+ pass
573
+
574
+ # Clear deleted GPU items
575
+ import gc
576
+ for _ in range(3):
577
+ gc.collect()
578
+ torch.cuda.empty_cache()
579
+ return model
580
+ pass
581
+ pass
unsloth-main/unsloth-main/unsloth/models/llama.py ADDED
The diff for this file is too large to render. See raw diff